| | |
| | | |
| | | return loss_att, acc_att, None, None |
| | | |
| | | def init_beam_search( |
| | | self, |
| | | **kwargs, |
| | | ): |
| | | from .search import BeamSearch |
| | | |
| | | from funasr.models.transformer.scorers.length_bonus import LengthBonus |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | | |
| | | scorers.update( |
| | | decoder=self.model.decoder, |
| | | length_bonus=LengthBonus(self.vocab_size), |
| | | ) |
| | | |
| | | weights = dict( |
| | | decoder=1.0, |
| | | ctc=0.0, |
| | | lm=0.0, |
| | | ngram=0.0, |
| | | length_bonus=kwargs.get("penalty", 0.0), |
| | | ) |
| | | beam_search = BeamSearch( |
| | | beam_size=kwargs.get("beam_size", 5), |
| | | weights=weights, |
| | | scorers=scorers, |
| | | sos=None, |
| | | eos=None, |
| | | vocab_size=self.vocab_size, |
| | | token_list=None, |
| | | pre_beam_score_key="full", |
| | | ) |
| | | |
| | | self.beam_search = beam_search |
| | | |
    def inference(
        self,
        data_in,
        # NOTE(review): this signature appears truncated by extraction — the
        # body reads `kwargs`, `key`, `tokenizer`, `frontend`, `speech`,
        # `speech_lengths`, `task`, `DecodingOptions` and `meta_data`, none of
        # which are bound in the visible code. Presumably the original carried
        # further parameters (key, tokenizer, frontend, **kwargs, ...) plus
        # some preprocessing lines — confirm against the upstream source.
    ):
        """Decode a single utterance with Whisper.

        Encodes the input once, then produces text via two paths: the
        generic ``self.beam_search`` (initialized lazily on first call) and
        ``whisper.decode`` with the accumulated ``DecodingOptions``. Results
        from both paths are appended to ``results``. Batch decoding is
        explicitly unsupported.
        """
        # Only single-utterance decoding is implemented.
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch — build it lazily the first time inference runs.
        if not hasattr(self, "beam_search") or self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        # NOTE(review): `frontend_class` is assigned but never used in the
        # visible code — the instantiation lines were likely lost in
        # extraction; verify against the original file.
        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")

        # Wrap each task token in Whisper's special-token markup,
        # e.g. "transcribe" -> "<|transcribe|>", and prepend the SOT token
        # unless the caller supplied an explicit initial prompt.
        task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt

        # "auto" means: let Whisper detect the language (pass None through).
        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language

        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
        # Encode sos/eos prompt strings to token-id sequences for the beam
        # search; a known language tag is appended to the sos prompt.
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
        eos = kwargs.get("model_conf").get("eos")
        eos_int = tokenizer.encode(eos, allowed_special="all")
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]

        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True
        # Add a batch dim and permute to (batch, feat, time) before encoding.
        # NOTE(review): assumes `speech` is a 2-D (time, feat) tensor —
        # confirm with the (missing) preprocessing code.
        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )

        options = whisper.DecodingOptions(**DecodingOptions)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )

        # Second decoding path: Whisper's built-in decoder on the raw input.
        result = whisper.decode(self.model, speech, options)
        text = f"{result.text}"
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        # First entry: the whisper.decode transcript, keyed by the first key.
        result_i = {"key": key[0], "text": text}
        b, n, d = encoder_out.size()
        for i in range(b):

            results.append(result_i)
            # NOTE(review): nesting of this loop inside `for i in range(b)`
            # is reconstructed from the `key[i]` references below — the
            # mangled source does not preserve indentation; confirm upstream.
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # # remove blank symbol id, which is assumed to be 0
                # token_int = list(
                #     filter(
                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
                #     )
                # )

                # Change integer-ids to tokens
                # token = tokenizer.ids2tokens(token_int)
                text = tokenizer.decode(token_int)

                result_i = {"key": key[i], "text": text}
                results.append(result_i)

                if ibest_writer is not None:
                    # ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text

        # NOTE(review): `meta_data` is never assigned in the visible code —
        # presumably built by the missing preprocessing section.
        return results, meta_data
| | | |