From cfe577f16fef9fb5b0a48f07d4f9e232799cc9d4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 08 五月 2024 00:03:52 +0800
Subject: [PATCH] decoding key

---
 funasr/models/sense_voice/model.py |  269 ++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 files changed, 241 insertions(+), 28 deletions(-)

diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index c12107e..0230638 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -15,6 +15,7 @@
 from funasr.train_utils.device_funcs import force_gatherable
 from . import whisper_lib as whisper
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
 
 from funasr.register import tables
 
@@ -377,14 +378,19 @@
         stats = {}
 
         # 1. Forward decoder
+        # ys_pad: [sos, task, lid, text, eos]
         decoder_out = self.model.decoder(
             x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
         )
 
         # 2. Compute attention loss
-        mask = torch.ones_like(ys_pad) * (-1)
-        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
-        ys_pad_mask[ys_pad_mask == 0] = -1
+        mask = torch.ones_like(ys_pad) * (-1)  # [sos, task, lid, text, eos]: [-1, -1, -1, -1]
+        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(
+            torch.int64
+        )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + [-1, -1, 0, 0, 0]
+        ys_pad_mask[ys_pad_mask == 0] = -1  # [-1, -1, lid, text, eos]
+        # decoder_out: [sos, task, lid, text]
+        # ys_pad_mask: [-1, lid, text, eos]
         loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
 
         with torch.no_grad():
@@ -394,6 +400,42 @@
             )
 
         return loss_att, acc_att, None, None
+
+    def init_beam_search(
+        self,
+        **kwargs,
+    ):
+        from .search import BeamSearch
+
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+        # 1. Build ASR model
+        scorers = {}
+
+        scorers.update(
+            decoder=self.model.decoder,
+            length_bonus=LengthBonus(self.vocab_size),
+        )
+
+        weights = dict(
+            decoder=1.0,
+            ctc=0.0,
+            lm=0.0,
+            ngram=0.0,
+            length_bonus=kwargs.get("penalty", 0.0),
+        )
+        beam_search = BeamSearch(
+            beam_size=kwargs.get("beam_size", 5),
+            weights=weights,
+            scorers=scorers,
+            sos=None,
+            eos=None,
+            vocab_size=self.vocab_size,
+            token_list=None,
+            pre_beam_score_key="full",
+        )
+
+        self.beam_search = beam_search
 
     def inference(
         self,
@@ -406,6 +448,12 @@
     ):
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
+
+        # init beamsearch
+        if not hasattr(self, "beam_search") or self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
 
         if frontend is None and not hasattr(self, "frontend"):
             frontend_class = tables.frontend_classes.get("WhisperFrontend")
@@ -455,25 +503,65 @@
             task = [task]
         task = "".join([f"<|{x}|>" for x in task])
         initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
-        DecodingOptions["initial_prompt"] = initial_prompt
 
         language = DecodingOptions.get("language", None)
         language = None if language == "auto" else language
-        DecodingOptions["language"] = language
 
-        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
+        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+        sos_int = tokenizer.encode(sos, allowed_special="all")
+        eos = kwargs.get("model_conf").get("eos")
+        eos_int = tokenizer.encode(eos, allowed_special="all")
+        self.beam_search.sos = sos_int
+        self.beam_search.eos = eos_int[0]
 
-        if "without_timestamps" not in DecodingOptions:
-            DecodingOptions["without_timestamps"] = True
+        encoder_out, encoder_out_lens = self.encode(
+            speech[None, :, :].permute(0, 2, 1), speech_lengths
+        )
 
-        options = whisper.DecodingOptions(**DecodingOptions)
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0],
+            maxlenratio=kwargs.get("maxlenratio", 0.0),
+            minlenratio=kwargs.get("minlenratio", 0.0),
+        )
 
-        result = whisper.decode(self.model, speech, options)
-        text = f"{result.text}"
+        nbest_hyps = nbest_hyps[: self.nbest]
+
         results = []
-        result_i = {"key": key[0], "text": text}
+        b, n, d = encoder_out.size()
+        for i in range(b):
 
-        results.append(result_i)
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # # remove blank symbol id, which is assumed to be 0
+                # token_int = list(
+                #     filter(
+                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+                #     )
+                # )
+
+                # Change integer-ids to tokens
+                # token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.decode(token_int)
+
+                result_i = {"key": key[i], "text": text}
+                results.append(result_i)
+
+                if ibest_writer is not None:
+                    # ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text
 
         return results, meta_data
 
@@ -497,12 +585,14 @@
         # decoder
         del model.decoder
         decoder = kwargs.get("decoder", "SenseVoiceDecoder")
-        decoder_conf = kwargs.get("decoder_conf", {})
         decoder_class = tables.decoder_classes.get(decoder)
         decoder = decoder_class(
-            vocab_size=dims.n_vocab,
-            encoder_output_size=dims.n_audio_state,
-            **decoder_conf,
+            n_vocab=dims.n_vocab,
+            n_ctx=dims.n_text_ctx,
+            n_state=dims.n_text_state,
+            n_head=dims.n_text_head,
+            n_layer=dims.n_text_layer,
+            **kwargs.get("decoder_conf"),
         )
         model.decoder = decoder
 
@@ -512,7 +602,7 @@
 
         self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
         self.ignore_id = kwargs.get("ignore_id", -1)
-        self.vocab_size = kwargs.get("vocab_size", -1)
+        self.vocab_size = dims.n_vocab
         self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
         self.criterion_att = LabelSmoothingLoss(
             size=self.vocab_size,
@@ -630,6 +720,42 @@
 
         return loss_att, acc_att, None, None
 
+    def init_beam_search(
+        self,
+        **kwargs,
+    ):
+        from .search import BeamSearch
+
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+        # 1. Build ASR model
+        scorers = {}
+
+        scorers.update(
+            decoder=self.model.decoder,
+            length_bonus=LengthBonus(self.vocab_size),
+        )
+
+        weights = dict(
+            decoder=1.0,
+            ctc=0.0,
+            lm=0.0,
+            ngram=0.0,
+            length_bonus=kwargs.get("penalty", 0.0),
+        )
+        beam_search = BeamSearch(
+            beam_size=kwargs.get("beam_size", 5),
+            weights=weights,
+            scorers=scorers,
+            sos=None,
+            eos=None,
+            vocab_size=self.vocab_size,
+            token_list=None,
+            pre_beam_score_key="full",
+        )
+
+        self.beam_search = beam_search
+
     def inference(
         self,
         data_in,
@@ -641,6 +767,12 @@
     ):
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
+
+        # init beamsearch
+        if not hasattr(self, "beam_search") or self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
 
         if frontend is None and not hasattr(self, "frontend"):
             frontend_class = tables.frontend_classes.get("WhisperFrontend")
@@ -670,6 +802,16 @@
                 data_type=kwargs.get("data_type", "sound"),
                 tokenizer=tokenizer,
             )
+
+            if (
+                isinstance(kwargs.get("data_type", None), (list, tuple))
+                and len(kwargs.get("data_type", [])) > 1
+            ):
+                audio_sample_list, text_token_int_list = audio_sample_list
+                text_token_int = text_token_int_list[0]
+            else:
+                text_token_int = None
+
             time2 = time.perf_counter()
             meta_data["load_data"] = f"{time2 - time1:0.3f}"
             speech, speech_lengths = extract_fbank(
@@ -690,24 +832,95 @@
             task = [task]
         task = "".join([f"<|{x}|>" for x in task])
         initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
-        DecodingOptions["initial_prompt"] = initial_prompt
 
         language = DecodingOptions.get("language", None)
         language = None if language == "auto" else language
-        DecodingOptions["language"] = language
 
-        DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
+        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+        sos_int = tokenizer.encode(sos, allowed_special="all")
+        eos = kwargs.get("model_conf").get("eos")
+        eos_int = tokenizer.encode(eos, allowed_special="all")
+        self.beam_search.sos = sos_int
+        self.beam_search.eos = eos_int[0]
 
-        if "without_timestamps" not in DecodingOptions:
-            DecodingOptions["without_timestamps"] = True
+        encoder_out, encoder_out_lens = self.encode(
+            speech[None, :, :].permute(0, 2, 1), speech_lengths
+        )
 
-        options = whisper.DecodingOptions(**DecodingOptions)
+        if text_token_int is not None:
+            i = 0
+            results = []
+            ibest_writer = None
+            if kwargs.get("output_dir") is not None:
+                if not hasattr(self, "writer"):
+                    self.writer = DatadirWriter(kwargs.get("output_dir"))
+                ibest_writer = self.writer[f"1best_recog"]
 
-        result = whisper.decode(self.model, speech, options)
-        text = f"{result.text}"
+            # 1. Forward decoder
+            ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
+                None, :
+            ]
+            ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
+                kwargs["device"]
+            )[None, :]
+            decoder_out = self.model.decoder(
+                x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+            )
+
+            token_int = decoder_out.argmax(-1)[0, :].tolist()
+            text = tokenizer.decode(token_int)
+
+            result_i = {"key": key[i], "text": text}
+            results.append(result_i)
+
+            if ibest_writer is not None:
+                # ibest_writer["token"][key[i]] = " ".join(token)
+                ibest_writer["text"][key[i]] = text
+            return results, meta_data
+
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0],
+            maxlenratio=kwargs.get("maxlenratio", 0.0),
+            minlenratio=kwargs.get("minlenratio", 0.0),
+        )
+
+        nbest_hyps = nbest_hyps[: self.nbest]
+
         results = []
-        result_i = {"key": key[0], "text": text}
+        b, n, d = encoder_out.size()
+        for i in range(b):
 
-        results.append(result_i)
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # # remove blank symbol id, which is assumed to be 0
+                # token_int = list(
+                #     filter(
+                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+                #     )
+                # )
+
+                # Change integer-ids to tokens
+                # token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.decode(token_int)
+
+                result_i = {"key": key[i], "text": text}
+                results.append(result_i)
+
+                if ibest_writer is not None:
+                    # ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text
 
         return results, meta_data

--
Gitblit v1.9.1