    def init_beam_search(self,
                         **kwargs,
                         ):
        from funasr.models.scama.beam_search import BeamSearchScamaStreaming
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # build the individual scorers that beam search combines
        # (scorer setup partially reconstructed; the original lines were elided)
        scorers = {}
        if self.ctc is not None:
            scorers["ctc"] = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
        token_list = kwargs.get("token_list")
        scorers["length_bonus"] = LengthBonus(len(token_list))
        ngram = None  # ngram LM scoring is not wired up in this model
        scorers["ngram"] = ngram

        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
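        # the decoder weight is defined as 1 - decoding_ctc_weight, so decoder
        # and CTC scores always form a convex combination; LM, ngram and length
        # bonus are added on top with their own weights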
        beam_search = BeamSearchScamaStreaming(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
            # (constructor tail reconstructed: standard BeamSearch arguments)
            sos=self.sos, eos=self.eos,
            vocab_size=len(token_list), token_list=token_list,
        )
        self.beam_search = beam_search

    # (method header reconstructed; the original lines were elided)
    def inference_chunk(self, speech, speech_lengths, key, tokenizer, cache: dict = {}, **kwargs):
        # encode the current chunk; is_final marks the last chunk of the stream
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache["encoder"], is_final=kwargs.get("is_final", False)
        )
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

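        # keep the running hypotheses in the per-session cache so that beam
        # search resumes from the previous chunk instead of restarting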
| | | if "running_hyps" not in cache: |
| | | running_hyps = self.beam_search.init_hyp(encoder_out) |
| | | cache["running_hyps"] = running_hyps |
| | | |
| | | |
        # predictor: estimate token count and acoustic embeddings for this chunk
        predictor_outs = self.calc_predictor_chunk(encoder_out,
                                                   encoder_out_lens,
                                                   cache=cache)  # (call tail reconstructed)
        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]

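        # the CIF-style predictor fires roughly once per emitted token; if it
        # did not fire in this chunk there is nothing to decode yet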
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
                                                             encoder_out_lens,
                                                             pre_acoustic_embeds,
                                                             pre_token_length,
                                                             cache=cache)
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
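
        # decoder_out holds one row of acoustic scores per predicted token; the
        # result loop below reads it as am_scores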
        token_num_relax = kwargs.get("token_num_relax", 5)
        maxlen = minlen = pre_token_length
        if kwargs.get("is_final", False):
            # use out-of-place arithmetic: maxlen and minlen alias the same
            # tensor, so += here would corrupt minlen as well
            maxlen = maxlen + token_num_relax
            minlen = max(0, minlen - token_num_relax)
        # c. pass the encoder result and the predictor output to beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            scama_mask=None,
            pre_acoustic_embeds=pre_acoustic_embeds,
            maxlen=int(maxlen),
            minlen=int(minlen),
            cache=cache,
        )
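
        # the full beam is carried over in the cache for the next chunk; only
        # the n-best hypotheses are surfaced to the caller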
| | | cache["running_hyps"] = nbest_hyps |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |

        results = []
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        for i in range(b):
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is None:
                # greedy fallback when beam search is disabled: take the argmax
                # token at each predictor step and sum the step-wise maxima as
                # the hypothesis score
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor(
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)
                )

                # map ids back to text (result packaging reconstructed; elided in source)
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                results.append({"key": key[i], "text": text})
        return results

    def init_cache(self, cache: dict = {}, **kwargs):
        device = kwargs.get("device", "cuda")

        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1  # one stream per cache in streaming inference

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
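        # encoder-side cache: CIF accumulator state plus a feature buffer of
        # chunk_size[0] + chunk_size[2] frames carried across chunks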
        cache_encoder = {"start_idx": 0,
                         "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
                         "cif_alphas": torch.zeros((batch_size, 1)).to(device=device),
                         "chunk_size": chunk_size,
                         "encoder_chunk_look_back": encoder_chunk_look_back,
                         "last_chunk": False,
                         "opt": None,
                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device),
                         "tail_chunk": False}
        cache["encoder"] = cache_encoder

        # (head of this dict was elided; fields reconstructed from context)
        cache_decoder = {"decode_fsmn": None,
                         "decoder_chunk_look_back": decoder_chunk_look_back,
                         "opt": None,
                         "chunk_size": chunk_size}
        cache["decoder"] = cache_decoder
| | | cache["frontend"] = {} |
| | | cache["prev_samples"] = torch.empty(0) |
| | | |
| | | |
| | | |
| | | cache["prev_samples"] = torch.empty(0).to(device=device) |
| | | |
        return cache
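
    # typical streaming usage (sketch; exact argument names are illustrative):
    #   cache = model.init_cache({}, device="cuda", chunk_size=[0, 10, 5],
    #                            encoder_conf=enc_conf, frontend_conf=fe_conf)
    #   for is_final, chunk in audio_stream:
    #       results = model.inference(chunk, key=["utt1"], tokenizer=tokenizer,
    #                                 cache=cache, is_final=is_final)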

    # (parameter list reconstructed; the original signature lines were elided)
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  cache: dict = {},
                  **kwargs,
                  ):
        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)