From 37d7764ecf0e8cc1a14f59b8b9cd1c914da8b005 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 21 一月 2024 21:06:52 +0800
Subject: [PATCH] Funasr1.0 (#1277)
---
funasr/models/scama/beam_search.py | 467 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 466 insertions(+), 1 deletions(-)
diff --git a/funasr/models/scama/beam_search.py b/funasr/models/scama/beam_search.py
index 8f0d751..b8aa876 100644
--- a/funasr/models/scama/beam_search.py
+++ b/funasr/models/scama/beam_search.py
@@ -11,7 +11,7 @@
import torch
-from funasr.metrics import end_detect
+from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
@@ -494,3 +494,468 @@
else:
remained_hyps.append(hyp)
return remained_hyps
+
+class BeamSearchScamaStreaming(torch.nn.Module):
+ """Beam search implementation."""
+
+ def __init__(
+ self,
+ scorers: Dict[str, ScorerInterface],
+ weights: Dict[str, float],
+ beam_size: int,
+ vocab_size: int,
+ sos: int,
+ eos: int,
+ token_list: List[str] = None,
+ pre_beam_ratio: float = 1.5,
+ pre_beam_score_key: str = None,
+ ):
+ """Initialize beam search.
+
+ Args:
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
+ e.g., Decoder, CTCPrefixScorer, LM
+ The scorer will be ignored if it is `None`
+ weights (dict[str, float]): Dict of weights for each scorers
+ The scorer will be ignored if its weight is 0
+ beam_size (int): The number of hypotheses kept during search
+ vocab_size (int): The number of vocabulary
+ sos (int): Start of sequence id
+ eos (int): End of sequence id
+ token_list (list[str]): List of tokens for debug log
+ pre_beam_score_key (str): key of scores to perform pre-beam search
+ pre_beam_ratio (float): beam size in the pre-beam search
+ will be `int(pre_beam_ratio * beam_size)`
+
+ """
+ super().__init__()
+ # set scorers
+ self.weights = weights
+ self.scorers = dict()
+ self.full_scorers = dict()
+ self.part_scorers = dict()
+ # this module dict is required for recursive cast
+ # `self.to(device, dtype)` in `recog.py`
+ self.nn_dict = torch.nn.ModuleDict()
+ for k, v in scorers.items():
+ w = weights.get(k, 0)
+ if w == 0 or v is None:
+ continue
+ assert isinstance(
+ v, ScorerInterface
+ ), f"{k} ({type(v)}) does not implement ScorerInterface"
+ self.scorers[k] = v
+ if isinstance(v, PartialScorerInterface):
+ self.part_scorers[k] = v
+ else:
+ self.full_scorers[k] = v
+ if isinstance(v, torch.nn.Module):
+ self.nn_dict[k] = v
+
+ # set configurations
+ self.sos = sos
+ self.eos = eos
+ self.token_list = token_list
+ self.pre_beam_size = int(pre_beam_ratio * beam_size)
+ self.beam_size = beam_size
+ self.n_vocab = vocab_size
+ if (
+ pre_beam_score_key is not None
+ and pre_beam_score_key != "full"
+ and pre_beam_score_key not in self.full_scorers
+ ):
+ raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+ self.pre_beam_score_key = pre_beam_score_key
+ self.do_pre_beam = (
+ self.pre_beam_score_key is not None
+ and self.pre_beam_size < self.n_vocab
+ and len(self.part_scorers) > 0
+ )
+
+ def init_hyp(self, x) -> List[Hypothesis]:
+ """Get an initial hypothesis data.
+
+ Args:
+ x (torch.Tensor): The encoder output feature
+
+ Returns:
+ Hypothesis: The initial hypothesis.
+
+ """
+ init_states = dict()
+ init_scores = dict()
+ for k, d in self.scorers.items():
+ init_states[k] = d.init_state(x)
+ init_scores[k] = 0.0
+ return [
+ Hypothesis(
+ score=0.0,
+ scores=init_scores,
+ states=init_states,
+ yseq=torch.tensor([self.sos], device=x.device),
+ )
+ ]
+
+ @staticmethod
+ def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+ """Append new token to prefix tokens.
+
+ Args:
+ xs (torch.Tensor): The prefix token
+ x (int): The new token to append
+
+ Returns:
+ torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+ """
+ x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+ return torch.cat((xs, x))
+
+ def score_full(
+ self, hyp: Hypothesis,
+ x: torch.Tensor,
+ x_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ cache: dict={},
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.full_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.full_scorers`
+ and tensor score values of shape: `(self.n_vocab,)`,
+ and state dict that has string keys
+ and state values of `self.full_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.full_scorers.items():
+ scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+ return scores, states
+
+ def score_partial(
+ self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """Score new hypothesis by `self.part_scorers`.
+
+ Args:
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
+ ids (torch.Tensor): 1D tensor of new partial tokens to score
+ x (torch.Tensor): Corresponding input feature
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+ score dict of `hyp` that has string keys of `self.part_scorers`
+ and tensor score values of shape: `(len(ids),)`,
+ and state dict that has string keys
+ and state values of `self.part_scorers`
+
+ """
+ scores = dict()
+ states = dict()
+ for k, d in self.part_scorers.items():
+ scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+ return scores, states
+
+ def beam(
+ self, weighted_scores: torch.Tensor, ids: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute topk full token ids and partial token ids.
+
+ Args:
+ weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+ Its shape is `(self.n_vocab,)`.
+ ids (torch.Tensor): The partial token ids to compute topk
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]:
+ The topk full token ids and partial token ids.
+ Their shapes are `(self.beam_size,)`
+
+ """
+ # no pre beam performed
+ if weighted_scores.size(0) == ids.size(0):
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ return top_ids, top_ids
+
+ # mask pruned in pre-beam not to select in topk
+ tmp = weighted_scores[ids]
+ weighted_scores[:] = -float("inf")
+ weighted_scores[ids] = tmp
+ top_ids = weighted_scores.topk(self.beam_size)[1]
+ local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+ return top_ids, local_ids
+
+ @staticmethod
+ def merge_scores(
+ prev_scores: Dict[str, float],
+ next_full_scores: Dict[str, torch.Tensor],
+ full_idx: int,
+ next_part_scores: Dict[str, torch.Tensor],
+ part_idx: int,
+ ) -> Dict[str, torch.Tensor]:
+ """Merge scores for new hypothesis.
+
+ Args:
+ prev_scores (Dict[str, float]):
+ The previous hypothesis scores by `self.scorers`
+ next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+ full_idx (int): The next token id for `next_full_scores`
+ next_part_scores (Dict[str, torch.Tensor]):
+ scores of partial tokens by `self.part_scorers`
+ part_idx (int): The new token id for `next_part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
+ Its values are scalar tensors by the scorers.
+
+ """
+ new_scores = dict()
+ for k, v in next_full_scores.items():
+ new_scores[k] = prev_scores[k] + v[full_idx]
+ for k, v in next_part_scores.items():
+ new_scores[k] = prev_scores[k] + v[part_idx]
+ return new_scores
+
+ def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+ """Merge states for new hypothesis.
+
+ Args:
+ states: states of `self.full_scorers`
+ part_states: states of `self.part_scorers`
+ part_idx (int): The new token id for `part_scores`
+
+ Returns:
+ Dict[str, torch.Tensor]: The new score dict.
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
+ Its values are states of the scorers.
+
+ """
+ new_states = dict()
+ for k, v in states.items():
+ new_states[k] = v
+ for k, d in self.part_scorers.items():
+ new_states[k] = d.select_state(part_states[k], part_idx)
+ return new_states
+
+ def search(
+ self, running_hyps: List[Hypothesis],
+ x: torch.Tensor,
+ x_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ cache: dict={},
+ ) -> List[Hypothesis]:
+ """Search new tokens for running hypotheses and encoded speech x.
+
+ Args:
+ running_hyps (List[Hypothesis]): Running hypotheses on beam
+ x (torch.Tensor): Encoded speech feature (T, D)
+
+ Returns:
+ List[Hypotheses]: Best sorted hypotheses
+
+ """
+ best_hyps = []
+ part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
+ for hyp in running_hyps:
+ # scoring
+ weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+ scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+ for k in self.full_scorers:
+ weighted_scores += self.weights[k] * scores[k]
+ # partial scoring
+ if self.do_pre_beam:
+ pre_beam_scores = (
+ weighted_scores
+ if self.pre_beam_score_key == "full"
+ else scores[self.pre_beam_score_key]
+ )
+ part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+ part_scores, part_states = self.score_partial(hyp, part_ids, x)
+ for k in self.part_scorers:
+ weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+ # add previous hyp score
+ weighted_scores += hyp.score
+
+ # update hyps
+ for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+ # will be (2 x beam at most)
+ best_hyps.append(
+ Hypothesis(
+ score=weighted_scores[j],
+ yseq=self.append_token(hyp.yseq, j),
+ scores=self.merge_scores(
+ hyp.scores, scores, j, part_scores, part_j
+ ),
+ states=self.merge_states(states, part_states, part_j),
+ )
+ )
+
+ # sort and prune 2 x beam -> beam
+ best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+ : min(len(best_hyps), self.beam_size)
+ ]
+ return best_hyps
+
+ def forward(
+ self, x: torch.Tensor,
+ scama_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ maxlenratio: float = 0.0,
+ minlenratio: float = 0.0,
+ maxlen: int = None,
+ minlen: int = 0,
+ cache:dict={},
+ ) -> List[Hypothesis]:
+ """Perform beam search.
+
+ Args:
+ x (torch.Tensor): Encoded speech feature (T, D)
+ maxlenratio (float): Input length ratio to obtain max output length.
+ If maxlenratio=0.0 (default), it uses a end-detect function
+ to automatically find maximum hypothesis lengths
+ If maxlenratio<0.0, its absolute value is interpreted
+ as a constant max output length.
+ minlenratio (float): Input length ratio to obtain min output length.
+
+ Returns:
+ list[Hypothesis]: N-best decoding results
+
+ """
+ if maxlen is None:
+ # set length bounds
+ if maxlenratio == 0:
+ maxlen = x.shape[0]
+ elif maxlenratio < 0:
+ maxlen = -1 * int(maxlenratio)
+ else:
+ maxlen = max(1, int(maxlenratio * x.size(0)))
+ minlen = int(minlenratio * x.size(0))
+
+ logging.info("decoder input length: " + str(x.shape[0]))
+ logging.info("max output length: " + str(maxlen))
+ logging.info("min output length: " + str(minlen))
+
+ # main loop of prefix search
+ # running_hyps = self.init_hyp(x)
+ running_hyps = cache["running_hyps"]
+ ended_hyps = []
+ for i in range(maxlen):
+ logging.debug("position " + str(i))
+ mask_enc = None
+ # if scama_mask is not None:
+ # token_num_predictor = scama_mask.size(1)
+ # token_id_slice = min(i, token_num_predictor-1)
+ # mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
+ # # if mask_enc.size(1) == 0:
+ # # mask_enc = scama_mask[:, -2:-1, :]
+ # # # mask_enc = torch.zeros_like(mask_enc)
+ pre_acoustic_embeds_cur = None
+ if pre_acoustic_embeds is not None:
+ b, t, d = pre_acoustic_embeds.size()
+ pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+ pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
+ token_id_slice = min(i, t)
+ pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
+
+ best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur, cache=cache["decoder"])
+ # post process of one iteration
+ running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+ # end detection
+ if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+ logging.info(f"end detected at {i}")
+ break
+ if len(running_hyps) == 0:
+ logging.info("no hypothesis. Finish decoding.")
+ break
+ else:
+ logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+ nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+ # check the number of hypotheses reaching to eos
+ if len(nbest_hyps) == 0:
+ logging.warning(
+ "there is no N-best results, perform recognition "
+ "again with smaller minlenratio."
+ )
+ return (
+ []
+ if minlenratio < 0.1
+ else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+ )
+
+ # report the best result
+ for x in nbest_hyps:
+ yseq = "".join([self.token_list[x] for x in x.yseq])
+ logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
+ best = nbest_hyps[0]
+ for k, v in best.scores.items():
+ logging.info(
+ f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+ )
+ logging.info(f"total log probability: {best.score:.2f}")
+ logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+ logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+ if self.token_list is not None:
+ logging.info(
+ "best hypo: "
+ + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+ + "\n"
+ )
+ return nbest_hyps
+
+ def post_process(
+ self,
+ i: int,
+ maxlen: int,
+ maxlenratio: float,
+ running_hyps: List[Hypothesis],
+ ended_hyps: List[Hypothesis],
+ ) -> List[Hypothesis]:
+ """Perform post-processing of beam search iterations.
+
+ Args:
+ i (int): The length of hypothesis tokens.
+ maxlen (int): The maximum length of tokens in beam search.
+ maxlenratio (int): The maximum length ratio in beam search.
+ running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+ ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+ Returns:
+ List[Hypothesis]: The new running hypotheses.
+
+ """
+ logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+ if self.token_list is not None:
+ logging.debug(
+ "best hypo: "
+ + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+ )
+ # add eos in the final loop to avoid that there are no ended hyps
+ if i == maxlen - 1:
+ logging.info("adding <eos> in the last position in the loop")
+ running_hyps = [
+ h._replace(yseq=self.append_token(h.yseq, self.eos))
+ for h in running_hyps
+ ]
+
+ # add ended hypotheses to a final list, and removed them from current hypotheses
+ # (this will be a problem, number of hyps < beam)
+ remained_hyps = []
+ for hyp in running_hyps:
+ if hyp.yseq[-1] == self.eos:
+ # e.g., Word LM needs to add final <eos> score
+ for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+ s = d.final_score(hyp.states[k])
+ hyp.scores[k] += s
+ hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+ ended_hyps.append(hyp)
+ else:
+ remained_hyps.append(hyp)
+ return remained_hyps
--
Gitblit v1.9.1