| | |
| | | )._asdict() |
| | | |
| | | |
| | | |
| | | class BeamSearchScama(torch.nn.Module): |
| | | """Beam search implementation.""" |
| | | |
| | |
# NOTE(review): extraction artifact — each row is wrapped in markdown-table
# pipes and this chunk interleaves pre/post autoformat duplicates of a diff.
# The `def` line owning the `return` below is missing; judging by the call
# sites further down (self.append_token(h.yseq, self.eos)), it is presumably
# an `append_token(xs, x)` helper that concatenates token `x` onto the
# sequence tensor `xs` — TODO confirm against the upstream source.
| | | return torch.cat((xs, x)) |
| | | |
def score_full(
    self,
    hyp: Hypothesis,
    x: torch.Tensor,
    x_mask: torch.Tensor = None,
    pre_acoustic_embeds: torch.Tensor = None,
):
    """Score a hypothesis with every full scorer.

    Args:
        hyp: Hypothesis to extend; ``hyp.yseq`` is the token-id sequence so
            far and ``hyp.states[k]`` is scorer ``k``'s decoder state.
        x: Encoder output the scorers condition on.
        x_mask: Optional SCAMA streaming mask forwarded to each scorer.
        pre_acoustic_embeds: Optional predictor acoustic embedding for the
            current step, forwarded to each scorer.

    Returns:
        Two dicts keyed by scorer name: per-vocabulary score tensors and the
        scorers' updated decoder states.
    """
    # NOTE(review): the garbled original duplicated the parameter lines
    # (SyntaxError: `self` twice) and issued d.score twice per scorer; a
    # single call per scorer is the intended behavior.
    scores = dict()
    states = dict()
    for k, d in self.full_scorers.items():
        scores[k], states[k] = d.score(
            hyp.yseq,
            hyp.states[k],
            x,
            x_mask=x_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
        )
    return scores, states
| | | |
# NOTE(review): this method is truncated by the extraction — only the `def`
# line survives, and the trailing `return new_states` looks like the tail of
# a separate merge_states-style helper rather than score_partial itself.
# Recover the full bodies from the upstream source before editing.
| | | def score_partial( |
| | |
| | | return new_states |
| | | |
# One step of beam search: score every running hypothesis and build the new
# beam. NOTE(review): the text below interleaves pre/post autoformat
# duplicates of the same lines (e.g. the signature's `self`/`running_hyps`
# appear twice, score_full is called twice, `scores=` is given twice inside
# Hypothesis); the pre-beam/partial-scoring section, the pruning logic, and
# the initialization of `best_hyps` are missing entirely. Do not run as-is.
| | | def search( |
| | | self, running_hyps: List[Hypothesis], |
| | | self, |
| | | running_hyps: List[Hypothesis], |
| | | x: torch.Tensor, |
| | | x_mask: torch.Tensor = None, |
| | | pre_acoustic_embeds: torch.Tensor = None, |
| | |
| | | for hyp in running_hyps: |
| | | # scoring |
# Accumulate weighted log-probabilities over the whole vocabulary for this
# hypothesis, one weight per full scorer.
| | | weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device) |
| | | scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds) |
| | | scores, states = self.score_full( |
| | | hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds |
| | | ) |
| | | for k in self.full_scorers: |
| | | weighted_scores += self.weights[k] * scores[k] |
| | | # partial scoring |
| | |
# NOTE(review): the loop producing `j`/`part_j` and the append call that this
# Hypothesis(...) constructor feeds are lost; `part_scores`/`part_states`
# presumably come from the missing partial-scoring section — confirm upstream.
| | | Hypothesis( |
| | | score=weighted_scores[j], |
| | | yseq=self.append_token(hyp.yseq, j), |
| | | scores=self.merge_scores( |
| | | hyp.scores, scores, j, part_scores, part_j |
| | | ), |
| | | scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j), |
| | | states=self.merge_states(states, part_states, part_j), |
| | | ) |
| | | ) |
| | |
| | | return best_hyps |
| | | |
# Full beam-search decode over encoder output `x`, slicing a per-step SCAMA
# mask and predictor embedding for each decoding step. NOTE(review): the text
# interleaves pre/post autoformat duplicates (signature, slicing lines, the
# search call, log messages) and the surrounding loop (`for i in range(maxlen)`),
# the initialization of `running_hyps`/`ended_hyps`/`nbest_hyps`, and the end
# of the early-return branch are missing. Do not run as-is.
| | | def forward( |
| | | self, x: torch.Tensor, |
| | | self, |
| | | x: torch.Tensor, |
| | | scama_mask: torch.Tensor = None, |
| | | pre_acoustic_embeds: torch.Tensor = None, |
| | | maxlenratio: float = 0.0, |
| | |
# Select the encoder mask row for the current token index `i` (defined in the
# missing enclosing loop), clamped to the predictor's last row.
| | | mask_enc = None |
| | | if scama_mask is not None: |
| | | token_num_predictor = scama_mask.size(1) |
| | | token_id_slice = min(i, token_num_predictor-1) |
| | | mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :] |
| | | token_id_slice = min(i, token_num_predictor - 1) |
| | | mask_enc = scama_mask[:, token_id_slice : token_id_slice + 1, :] |
| | | # if mask_enc.size(1) == 0: |
| | | # mask_enc = scama_mask[:, -2:-1, :] |
| | | # # mask_enc = torch.zeros_like(mask_enc) |
# Pad the predictor embeddings with one zero frame so step i == t still
# yields a (b, 1, d) slice instead of an empty tensor.
| | | pre_acoustic_embeds_cur = None |
| | | if pre_acoustic_embeds is not None: |
| | | b, t, d = pre_acoustic_embeds.size() |
| | | pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device) |
| | | pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to( |
| | | device=pre_acoustic_embeds.device |
| | | ) |
| | | pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1) |
| | | token_id_slice = min(i, t) |
| | | pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :] |
| | | pre_acoustic_embeds_cur = pre_acoustic_embeds[ |
| | | :, token_id_slice : token_id_slice + 1, : |
| | | ] |
| | | |
# Extend the beam by one step, then prune/collect ended hypotheses.
| | | best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur) |
| | | best = self.search( |
| | | running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur |
| | | ) |
| | | # post process of one iteration |
| | | running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) |
| | | # end detection |
| | |
| | | # check the number of hypotheses reaching to eos |
| | | if len(nbest_hyps) == 0: |
| | | logging.warning( |
| | | "there is no N-best results, perform recognition " |
| | | "again with smaller minlenratio." |
| | | "there is no N-best results, perform recognition " "again with smaller minlenratio." |
| | | ) |
# NOTE(review): the early-return value is cut off after `[]` — presumably a
# retry with reduced minlenratio or an empty result; confirm upstream.
| | | return ( |
| | | [] |
| | |
# Report the best hypothesis and per-scorer score breakdown.
| | | logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score)) |
| | | best = nbest_hyps[0] |
| | | for k, v in best.scores.items(): |
| | | logging.info( |
| | | f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}" |
| | | ) |
| | | logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}") |
| | | logging.info(f"total log probability: {best.score:.2f}") |
| | | logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") |
| | | logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}") |
| | | if self.token_list is not None: |
| | | logging.info( |
| | | "best hypo: " |
| | | + "".join([self.token_list[x] for x in best.yseq[1:-1]]) |
| | | + "\n" |
| | | "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n" |
| | | ) |
| | | return nbest_hyps |
| | | |
| | |
# NOTE(review): interior of what is presumably `post_process(i, maxlen,
# maxlenratio, running_hyps, ended_hyps)` — the `def` line and the tail that
# moves ended hypotheses out of the beam were lost in extraction; the list
# comprehension below appears in both pre- and post-autoformat form.
| | | logging.debug(f"the number of running hypotheses: {len(running_hyps)}") |
| | | if self.token_list is not None: |
| | | logging.debug( |
| | | "best hypo: " |
| | | + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) |
| | | "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) |
| | | ) |
| | | # add eos in the final loop to avoid that there are no ended hyps |
| | | if i == maxlen - 1: |
| | | logging.info("adding <eos> in the last position in the loop") |
# Force-terminate every surviving hypothesis at the last step so at least
# one ended hypothesis exists (Hypothesis is a NamedTuple, hence _replace).
| | | running_hyps = [ |
| | | h._replace(yseq=self.append_token(h.yseq, self.eos)) |
| | | for h in running_hyps |
| | | h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps |
| | | ] |
| | | |
| | | # add ended hypotheses to a final list, and removed them from current hypotheses |