| | |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from . import whisper_lib as whisper |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | |
| | | from funasr.register import tables |
| | | |
| | |
| | | |
| | | return loss_att, acc_att, None, None |
| | | |
def init_beam_search(
    self,
    **kwargs,
):
    """Create the beam-search object and store it on ``self.beam_search``.

    Hypotheses are scored by ``self.model.decoder`` together with a
    ``LengthBonus`` sized by ``self.vocab_size``. ``sos`` and ``eos`` are
    passed as ``None`` here, so they have to be assigned on the returned
    object before it is used for decoding.

    Keyword Args:
        beam_size: Forwarded to ``BeamSearch`` as ``beam_size`` (default 5).
        penalty: Weight applied to the ``length_bonus`` scorer (default 0.0).
    """
    # Resolved at call time, not at module import.
    from .search import BeamSearch

    from funasr.models.transformer.scorers.length_bonus import LengthBonus

    # Modules that contribute a score to each hypothesis.
    scoring_modules = {
        "decoder": self.model.decoder,
        "length_bonus": LengthBonus(self.vocab_size),
    }

    # One weight per scorer name; "ctc", "lm" and "ngram" have no matching
    # entry in ``scoring_modules`` and are kept at 0.0.
    scorer_weights = {
        "decoder": 1.0,
        "ctc": 0.0,
        "lm": 0.0,
        "ngram": 0.0,
        "length_bonus": kwargs.get("penalty", 0.0),
    }

    self.beam_search = BeamSearch(
        beam_size=kwargs.get("beam_size", 5),
        weights=scorer_weights,
        scorers=scoring_modules,
        sos=None,
        eos=None,
        vocab_size=self.vocab_size,
        token_list=None,
        pre_beam_score_key="full",
    )
| | | |
| | | def inference( |
| | | self, |
| | | data_in, |
| | |
| | | ): |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | # init beamsearch |
| | | if not hasattr(self, "beam_search") or self.beam_search is None: |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | if frontend is None and not hasattr(self, "frontend"): |
| | | frontend_class = tables.frontend_classes.get("WhisperFrontend") |
| | |
| | | task = [task] |
| | | task = "".join([f"<|{x}|>" for x in task]) |
| | | initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}") |
| | | DecodingOptions["initial_prompt"] = initial_prompt |
| | | |
| | | language = DecodingOptions.get("language", None) |
| | | language = None if language == "auto" else language |
| | | DecodingOptions["language"] = language |
| | | |
| | | DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None) |
| | | sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt |
| | | sos_int = tokenizer.encode(sos, allowed_special="all") |
| | | eos = kwargs.get("model_conf").get("eos") |
| | | eos_int = tokenizer.encode(eos, allowed_special="all") |
| | | self.beam_search.sos = sos_int |
| | | self.beam_search.eos = eos_int[0] |
| | | |
| | | if "without_timestamps" not in DecodingOptions: |
| | | DecodingOptions["without_timestamps"] = True |
| | | encoder_out, encoder_out_lens = self.encode( |
| | | speech[None, :, :].permute(0, 2, 1), speech_lengths |
| | | ) |
| | | |
| | | options = whisper.DecodingOptions(**DecodingOptions) |
| | | # c. Passed the encoder result and the beam search |
| | | nbest_hyps = self.beam_search( |
| | | x=encoder_out[0], |
| | | maxlenratio=kwargs.get("maxlenratio", 0.0), |
| | | minlenratio=kwargs.get("minlenratio", 0.0), |
| | | ) |
| | | |
| | | result = whisper.decode(self.model, speech, options) |
| | | text = f"{result.text}" |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |
| | | results = [] |
| | | result_i = {"key": key[0], "text": text} |
| | | b, n, d = encoder_out.size() |
| | | for i in range(b): |
| | | |
| | | results.append(result_i) |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # # remove blank symbol id, which is assumed to be 0 |
| | | # token_int = list( |
| | | # filter( |
| | | # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int |
| | | # ) |
| | | # ) |
| | | |
| | | # Change integer-ids to tokens |
| | | # token = tokenizer.ids2tokens(token_int) |
| | | text = tokenizer.decode(token_int) |
| | | |
| | | result_i = {"key": key[i], "text": text} |
| | | results.append(result_i) |
| | | |
| | | if ibest_writer is not None: |
| | | # ibest_writer["token"][key[i]] = " ".join(token) |
| | | ibest_writer["text"][key[i]] = text |
| | | |
| | | return results, meta_data |
| | | |
| | |
| | | # decoder |
| | | del model.decoder |
| | | decoder = kwargs.get("decoder", "SenseVoiceDecoder") |
| | | decoder_conf = kwargs.get("decoder_conf", {}) |
| | | decoder_class = tables.decoder_classes.get(decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=dims.n_vocab, |
| | | encoder_output_size=dims.n_audio_state, |
| | | **decoder_conf, |
| | | n_vocab=dims.n_vocab, |
| | | n_ctx=dims.n_text_ctx, |
| | | n_state=dims.n_text_state, |
| | | n_head=dims.n_text_head, |
| | | n_layer=dims.n_text_layer, |
| | | **kwargs.get("decoder_conf"), |
| | | ) |
| | | model.decoder = decoder |
| | | |
| | |
| | | |
| | | self.activation_checkpoint = kwargs.get("activation_checkpoint", False) |
| | | self.ignore_id = kwargs.get("ignore_id", -1) |
| | | self.vocab_size = kwargs.get("vocab_size", -1) |
| | | self.vocab_size = dims.n_vocab |
| | | self.length_normalized_loss = kwargs.get("length_normalized_loss", True) |
| | | self.criterion_att = LabelSmoothingLoss( |
| | | size=self.vocab_size, |
| | |
| | | |
| | | return loss_att, acc_att, None, None |
| | | |
def init_beam_search(
    self,
    **kwargs,
):
    """Create the beam-search object and store it on ``self.beam_search``.

    Hypotheses are scored by ``self.model.decoder`` together with a
    ``LengthBonus`` sized by ``self.vocab_size``. ``sos`` and ``eos`` are
    passed as ``None`` here, so they have to be assigned on the returned
    object before it is used for decoding.

    Keyword Args:
        beam_size: Forwarded to ``BeamSearch`` as ``beam_size`` (default 5).
        penalty: Weight applied to the ``length_bonus`` scorer (default 0.0).
    """
    # Resolved at call time, not at module import.
    from .search import BeamSearch

    from funasr.models.transformer.scorers.length_bonus import LengthBonus

    # Modules that contribute a score to each hypothesis.
    scoring_modules = {
        "decoder": self.model.decoder,
        "length_bonus": LengthBonus(self.vocab_size),
    }

    # One weight per scorer name; "ctc", "lm" and "ngram" have no matching
    # entry in ``scoring_modules`` and are kept at 0.0.
    scorer_weights = {
        "decoder": 1.0,
        "ctc": 0.0,
        "lm": 0.0,
        "ngram": 0.0,
        "length_bonus": kwargs.get("penalty", 0.0),
    }

    self.beam_search = BeamSearch(
        beam_size=kwargs.get("beam_size", 5),
        weights=scorer_weights,
        scorers=scoring_modules,
        sos=None,
        eos=None,
        vocab_size=self.vocab_size,
        token_list=None,
        pre_beam_score_key="full",
    )
| | | |
| | | def inference( |
| | | self, |
| | | data_in, |
| | |
| | | ): |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | # init beamsearch |
| | | if not hasattr(self, "beam_search") or self.beam_search is None: |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | if frontend is None and not hasattr(self, "frontend"): |
| | | frontend_class = tables.frontend_classes.get("WhisperFrontend") |
| | |
| | | task = [task] |
| | | task = "".join([f"<|{x}|>" for x in task]) |
| | | initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}") |
| | | DecodingOptions["initial_prompt"] = initial_prompt |
| | | |
| | | language = DecodingOptions.get("language", None) |
| | | language = None if language == "auto" else language |
| | | DecodingOptions["language"] = language |
| | | |
| | | DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None) |
| | | sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt |
| | | sos_int = tokenizer.encode(sos, allowed_special="all") |
| | | eos = kwargs.get("model_conf").get("eos") |
| | | eos_int = tokenizer.encode(eos, allowed_special="all") |
| | | self.beam_search.sos = sos_int |
| | | self.beam_search.eos = eos_int[0] |
| | | |
| | | if "without_timestamps" not in DecodingOptions: |
| | | DecodingOptions["without_timestamps"] = True |
| | | encoder_out, encoder_out_lens = self.encode( |
| | | speech[None, :, :].permute(0, 2, 1), speech_lengths |
| | | ) |
| | | |
| | | options = whisper.DecodingOptions(**DecodingOptions) |
| | | # c. Passed the encoder result and the beam search |
| | | nbest_hyps = self.beam_search( |
| | | x=encoder_out[0], |
| | | maxlenratio=kwargs.get("maxlenratio", 0.0), |
| | | minlenratio=kwargs.get("minlenratio", 0.0), |
| | | ) |
| | | |
| | | result = whisper.decode(self.model, speech, options) |
| | | text = f"{result.text}" |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |
| | | results = [] |
| | | result_i = {"key": key[0], "text": text} |
| | | b, n, d = encoder_out.size() |
| | | for i in range(b): |
| | | |
| | | results.append(result_i) |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # # remove blank symbol id, which is assumed to be 0 |
| | | # token_int = list( |
| | | # filter( |
| | | # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int |
| | | # ) |
| | | # ) |
| | | |
| | | # Change integer-ids to tokens |
| | | # token = tokenizer.ids2tokens(token_int) |
| | | text = tokenizer.decode(token_int) |
| | | |
| | | result_i = {"key": key[i], "text": text} |
| | | results.append(result_i) |
| | | |
| | | if ibest_writer is not None: |
| | | # ibest_writer["token"][key[i]] = " ".join(token) |
| | | ibest_writer["text"][key[i]] = text |
| | | |
| | | return results, meta_data |