zhifu gao
2024-05-06 00d0df3a1018c63ec8c5d13e611f53c564c0a7e2
funasr/models/sense_voice/model.py
@@ -15,6 +15,7 @@
from funasr.train_utils.device_funcs import force_gatherable
from . import whisper_lib as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@@ -395,6 +396,42 @@
        return loss_att, acc_att, None, None
    def init_beam_search(
        self,
        **kwargs,
    ):
        from .search import BeamSearch
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
        # 1. Build ASR model
        scorers = {}
        scorers.update(
            decoder=self.model.decoder,
            length_bonus=LengthBonus(self.vocab_size),
        )
        weights = dict(
            decoder=1.0,
            ctc=0.0,
            lm=0.0,
            ngram=0.0,
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 5),
            weights=weights,
            scorers=scorers,
            sos=None,
            eos=None,
            vocab_size=self.vocab_size,
            token_list=None,
            pre_beam_score_key="full",
        )
        self.beam_search = beam_search
    def inference(
        self,
        data_in,
@@ -406,6 +443,12 @@
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        # init beamsearch
        if not hasattr(self, "beam_search") or self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
@@ -455,25 +498,65 @@
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt
        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language
        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
        eos = kwargs.get("model_conf").get("eos")
        eos_int = tokenizer.encode(eos, allowed_special="all")
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]
        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True
        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )
        options = whisper.DecodingOptions(**DecodingOptions)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )
        result = whisper.decode(self.model, speech, options)
        text = f"{result.text}"
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        result_i = {"key": key[0], "text": text}
        b, n, d = encoder_out.size()
        for i in range(b):
        results.append(result_i)
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # # remove blank symbol id, which is assumed to be 0
                # token_int = list(
                #     filter(
                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
                #     )
                # )
                # Change integer-ids to tokens
                # token = tokenizer.ids2tokens(token_int)
                text = tokenizer.decode(token_int)
                result_i = {"key": key[i], "text": text}
                results.append(result_i)
                if ibest_writer is not None:
                    # ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text
        return results, meta_data
@@ -497,12 +580,14 @@
        # decoder
        del model.decoder
        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
        decoder_conf = kwargs.get("decoder_conf", {})
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=dims.n_vocab,
            encoder_output_size=dims.n_audio_state,
            **decoder_conf,
            n_vocab=dims.n_vocab,
            n_ctx=dims.n_text_ctx,
            n_state=dims.n_text_state,
            n_head=dims.n_text_head,
            n_layer=dims.n_text_layer,
            **kwargs.get("decoder_conf"),
        )
        model.decoder = decoder
@@ -512,7 +597,7 @@
        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        self.vocab_size = kwargs.get("vocab_size", -1)
        self.vocab_size = dims.n_vocab
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
@@ -630,6 +715,42 @@
        return loss_att, acc_att, None, None
    def init_beam_search(
        self,
        **kwargs,
    ):
        from .search import BeamSearch
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
        # 1. Build ASR model
        scorers = {}
        scorers.update(
            decoder=self.model.decoder,
            length_bonus=LengthBonus(self.vocab_size),
        )
        weights = dict(
            decoder=1.0,
            ctc=0.0,
            lm=0.0,
            ngram=0.0,
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 5),
            weights=weights,
            scorers=scorers,
            sos=None,
            eos=None,
            vocab_size=self.vocab_size,
            token_list=None,
            pre_beam_score_key="full",
        )
        self.beam_search = beam_search
    def inference(
        self,
        data_in,
@@ -641,6 +762,12 @@
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        # init beamsearch
        if not hasattr(self, "beam_search") or self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
@@ -690,24 +817,64 @@
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt
        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language
        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
        eos = kwargs.get("model_conf").get("eos")
        eos_int = tokenizer.encode(eos, allowed_special="all")
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]
        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True
        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )
        options = whisper.DecodingOptions(**DecodingOptions)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )
        result = whisper.decode(self.model, speech, options)
        text = f"{result.text}"
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        result_i = {"key": key[0], "text": text}
        b, n, d = encoder_out.size()
        for i in range(b):
        results.append(result_i)
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # # remove blank symbol id, which is assumed to be 0
                # token_int = list(
                #     filter(
                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
                #     )
                # )
                # Change integer-ids to tokens
                # token = tokenizer.ids2tokens(token_int)
                text = tokenizer.decode(token_int)
                result_i = {"key": key[i], "text": text}
                results.append(result_i)
                if ibest_writer is not None:
                    # ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text
        return results, meta_data