From 2f27b165559cd53afab52047309ebe4ac838ebb8 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Tue, 14 May 2024 09:54:08 +0800
Subject: [PATCH] Add files via upload

---
 funasr/models/sense_voice/model.py |   94 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 91 insertions(+), 3 deletions(-)

diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index d5e4130..56e61e7 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -378,14 +378,19 @@
         stats = {}
 
         # 1. Forward decoder
+        # ys_pad: [sos, task, lid, text, eos]
         decoder_out = self.model.decoder(
             x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
         )
 
         # 2. Compute attention loss
-        mask = torch.ones_like(ys_pad) * (-1)
-        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
-        ys_pad_mask[ys_pad_mask == 0] = -1
+        mask = torch.ones_like(ys_pad) * (-1)  # [sos, task, lid, text, eos]: [-1, -1, -1, -1, -1]
+        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(
+            torch.int64
+        )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + [-1, -1, 0, 0, 0]
+        ys_pad_mask[ys_pad_mask == 0] = -1  # [-1, -1, lid, text, eos]
+        # decoder_out: [sos, task, lid, text]
+        # ys_pad_mask: [-1, lid, text, eos]
         loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
 
         with torch.no_grad():
@@ -508,6 +513,27 @@
         eos_int = tokenizer.encode(eos, allowed_special="all")
         self.beam_search.sos = sos_int
         self.beam_search.eos = eos_int[0]
+
+        # Parameters for rich decoding
+        self.beam_search.emo_unk = tokenizer.encode(
+            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
+        )[0]
+        self.beam_search.emo_unk_score = 1
+        self.beam_search.emo_tokens = tokenizer.encode(
+            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
+            allowed_special="all",
+        )
+        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+        self.beam_search.event_bg_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_ed_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
 
         encoder_out, encoder_out_lens = self.encode(
             speech[None, :, :].permute(0, 2, 1), speech_lengths
@@ -797,6 +823,16 @@
                 data_type=kwargs.get("data_type", "sound"),
                 tokenizer=tokenizer,
             )
+
+            if (
+                isinstance(kwargs.get("data_type", None), (list, tuple))
+                and len(kwargs.get("data_type", [])) > 1
+            ):
+                audio_sample_list, text_token_int_list = audio_sample_list
+                text_token_int = text_token_int_list[0]
+            else:
+                text_token_int = None
+
             time2 = time.perf_counter()
             meta_data["load_data"] = f"{time2 - time1:0.3f}"
             speech, speech_lengths = extract_fbank(
@@ -828,10 +864,62 @@
         self.beam_search.sos = sos_int
         self.beam_search.eos = eos_int[0]
 
+        # Parameters for rich decoding
+        self.beam_search.emo_unk = tokenizer.encode(
+            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
+        )[0]
+        self.beam_search.emo_unk_score = 1
+        self.beam_search.emo_tokens = tokenizer.encode(
+            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
+            allowed_special="all",
+        )
+        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+        self.beam_search.event_bg_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_ed_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
+
         encoder_out, encoder_out_lens = self.encode(
             speech[None, :, :].permute(0, 2, 1), speech_lengths
         )
 
+        if text_token_int is not None:
+            i = 0
+            results = []
+            ibest_writer = None
+            if kwargs.get("output_dir") is not None:
+                if not hasattr(self, "writer"):
+                    self.writer = DatadirWriter(kwargs.get("output_dir"))
+                ibest_writer = self.writer[f"1best_recog"]
+
+            # 1. Forward decoder
+            ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
+                None, :
+            ]
+            ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
+                kwargs["device"]
+            )[None, :]
+            decoder_out = self.model.decoder(
+                x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+            )
+
+            token_int = decoder_out.argmax(-1)[0, :].tolist()
+            text = tokenizer.decode(token_int)
+
+            result_i = {"key": key[i], "text": text}
+            results.append(result_i)
+
+            if ibest_writer is not None:
+                # ibest_writer["token"][key[i]] = " ".join(token)
+                ibest_writer["text"][key[i]] = text
+            return results, meta_data
+
         # c. Passed the encoder result and the beam search
         nbest_hyps = self.beam_search(
             x=encoder_out[0],

--
Gitblit v1.9.1