import logging
import time

import torch

from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables

from . import whisper_lib as whisper


        stats = {}

        # 1. Forward decoder
        # ys_pad: [sos, task, lid, text, eos]
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )

        # 2. Compute attention loss
        mask = torch.ones_like(ys_pad) * (-1)  # [sos, task, lid, text, eos]: [-1, -1, -1, -1, -1]
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(
            torch.int64
        )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + [-1, -1, 0, 0, 0]
        ys_pad_mask[ys_pad_mask == 0] = -1  # [-1, -1, lid, text, eos]
        # decoder_out: [sos, task, lid, text]
        # ys_pad_mask: [-1, lid, text, eos]
        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])

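        # Toy illustration of the masking above (symbolic values, not from a real batch):
        #   ys_pad      = [sos, task, lid, txt1, txt2, eos]
        #   target_mask = [  0,    0,   1,    1,    1,   1]
        #   ys_pad_mask = ys_pad * target_mask + (-1) * (1 - target_mask)
        #               = [ -1,   -1, lid, txt1, txt2, eos]
        # Assuming ignore_id keeps its default of -1, the prompt positions (sos, task)
        # drop out of the label-smoothing loss and only lid/text/eos are supervised.
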
        with torch.no_grad():
            # token accuracy over the supervised positions; assumes FunASR's th_accuracy
            # helper, with ignore_id masking out the prompt tokens set to -1 above
            acc_att = th_accuracy(
                decoder_out[:, :-1, :].contiguous().view(-1, self.vocab_size),
                ys_pad_mask[:, 1:],
                ignore_label=self.ignore_id,
            )

        return loss_att, acc_att, None, None

    def init_beam_search(
        self,
        **kwargs,
    ):
        from .search import BeamSearch

        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build scorers for beam search
        scorers = {}

        scorers.update(
            decoder=self.model.decoder,
            length_bonus=LengthBonus(self.vocab_size),
        )

        weights = dict(
            decoder=1.0,
            ctc=0.0,
            lm=0.0,
            ngram=0.0,
            length_bonus=kwargs.get("penalty", 0.0),
        )
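        # Only the decoder weight is non-zero and only decoder/length_bonus scorers are
        # registered, so beam scoring reduces to decoder log-probabilities plus an
        # optional length bonus ("penalty"); the ctc/lm/ngram weights have no effect.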
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 5),
            weights=weights,
            scorers=scorers,
            sos=None,
            eos=None,
            vocab_size=self.vocab_size,
            token_list=None,
            pre_beam_score_key="full",
        )

        self.beam_search = beam_search

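    # sos/eos are deliberately left as None above: inference() encodes the prompt
    # (initial_prompt plus optional language tag) with the tokenizer and assigns the
    # resulting ids to beam_search.sos / beam_search.eos before decoding starts.
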
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if not hasattr(self, "beam_search") or self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
        # ... frontend construction, audio loading, fbank extraction and
        # DecodingOptions setup ...
        task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt

        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language

        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
        eos = kwargs.get("model_conf").get("eos")
        eos_int = tokenizer.encode(eos, allowed_special="all")
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]

        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True
        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )
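        # speech[None] adds a batch axis and permute(0, 2, 1) swaps the last two
        # dimensions, so the features arrive as (batch, n_mels, n_frames), the layout
        # whisper's audio encoder expects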

        options = whisper.DecodingOptions(**DecodingOptions)
        # c. Pass the encoder output to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )

        result = whisper.decode(self.model, speech, options)
        text = f"{result.text}"
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        # keep whisper's own decoding result as the first entry
        result_i = {"key": key[0], "text": text}
        results.append(result_i)

        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # # remove blank symbol id, which is assumed to be 0
                # token_int = list(
                #     filter(
                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
                #     )
                # )

                # Change integer-ids to tokens
                # token = tokenizer.ids2tokens(token_int)
                text = tokenizer.decode(token_int)

                result_i = {"key": key[i], "text": text}
                results.append(result_i)

                if ibest_writer is not None:
                    # ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text

        return results, meta_data

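    # Rough usage sketch via FunASR's AutoModel front end (paths and option names here
    # are placeholders, not taken from this file):
    #
    #     from funasr import AutoModel
    #     model = AutoModel(model="<local model dir or model name>")
    #     res = model.generate(input="example.wav", batch_size=1)
    #     print(res[0]["text"])
    #
    # batch_size must stay at 1, since inference() above raises NotImplementedError
    # for larger batches.
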

        # decoder: drop whisper's own text decoder and rebuild it from the registered class
        del model.decoder
        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
        decoder_conf = kwargs.get("decoder_conf", {})
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=dims.n_vocab,
            encoder_output_size=dims.n_audio_state,
            n_vocab=dims.n_vocab,
            n_ctx=dims.n_text_ctx,
            n_state=dims.n_text_state,
            n_head=dims.n_text_head,
            n_layer=dims.n_text_layer,
            **decoder_conf,
        )
        model.decoder = decoder
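        # The original whisper text decoder is deleted and replaced by the decoder class
        # registered under `decoder` (default "SenseVoiceDecoder"), sized from the same
        # checkpoint dims so it lines up with the retained audio encoder.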

        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        self.vocab_size = kwargs.get("vocab_size", -1)
        self.vocab_size = dims.n_vocab
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            # remaining arguments assumed here, following FunASR's usual label-smoothing setup
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        return loss_att, acc_att, None, None

    def init_beam_search(
        self,
        **kwargs,
    ):
        from .search import BeamSearch

        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build scorers for beam search
        scorers = {}

        scorers.update(
            decoder=self.model.decoder,
            length_bonus=LengthBonus(self.vocab_size),
        )

        weights = dict(
            decoder=1.0,
            ctc=0.0,
            lm=0.0,
            ngram=0.0,
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 5),
            weights=weights,
            scorers=scorers,
            sos=None,
            eos=None,
            vocab_size=self.vocab_size,
            token_list=None,
            pre_beam_score_key="full",
        )

        self.beam_search = beam_search

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if not hasattr(self, "beam_search") or self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            # frontend construction and data loading below follow FunASR's usual pattern
            frontend = frontend_class(**kwargs.get("frontend_conf", {}))
            self.frontend = frontend
        frontend = frontend if frontend is not None else self.frontend

        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )

        if (
            isinstance(kwargs.get("data_type", None), (list, tuple))
            and len(kwargs.get("data_type", [])) > 1
        ):
            audio_sample_list, text_token_int_list = audio_sample_list
            text_token_int = text_token_int_list[0]
        else:
            text_token_int = None

        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
        )

        # assumed: the whisper-style DecodingOptions dict (and with it the task tag)
        # arrives through kwargs
        DecodingOptions = kwargs.get("DecodingOptions", {})
        task = DecodingOptions.get("task", "transcribe")
        task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt

        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language

        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
        eos = kwargs.get("model_conf").get("eos")
        eos_int = tokenizer.encode(eos, allowed_special="all")
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]
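        # Note: beam_search.sos receives the full list of prompt token ids
        # (<|startoftranscript|>, task tag and optional language tag), while
        # beam_search.eos is a single token id (the first id of the encoded eos string).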

        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True
        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )

        options = whisper.DecodingOptions(**DecodingOptions)
        if text_token_int is not None:
            i = 0
            results = []
            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer["1best_recog"]

            result = whisper.decode(self.model, speech, options)
            text = f"{result.text}"
            # 1. Forward decoder
            ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
                None, :
            ]
            ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
                kwargs["device"]
            )[None, :]
            decoder_out = self.model.decoder(
                x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
            )
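            # With the reference transcript fed in as decoder input (teacher forcing),
            # the per-position argmax below reads out the model's prediction for each
            # target slot instead of running beam search.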

            token_int = decoder_out.argmax(-1)[0, :].tolist()
            text = tokenizer.decode(token_int)

            result_i = {"key": key[i], "text": text}
            results.append(result_i)

            if ibest_writer is not None:
                # ibest_writer["token"][key[i]] = " ".join(token)
                ibest_writer["text"][key[i]] = text
            return results, meta_data

        # c. Pass the encoder output to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )

        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # # remove blank symbol id, which is assumed to be 0
                # token_int = list(
                #     filter(
                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
                #     )
                # )

                # Change integer-ids to tokens
                # token = tokenizer.ids2tokens(token_int)
                text = tokenizer.decode(token_int)

                result_i = {"key": key[i], "text": text}
                results.append(result_i)

                if ibest_writer is not None:
                    # ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text

        return results, meta_data