From 33a9e08dc9b65abc3f3e18d14253f95c79e0f749 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 08 五月 2024 19:20:43 +0800
Subject: [PATCH] dynamic batch

---
 funasr/models/sense_voice/model.py |   42 ++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 42 insertions(+), 0 deletions(-)

diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 0230638..56e61e7 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -514,6 +514,27 @@
         self.beam_search.sos = sos_int
         self.beam_search.eos = eos_int[0]
 
+        # Paramterts for rich decoding
+        self.beam_search.emo_unk = tokenizer.encode(
+            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
+        )[0]
+        self.beam_search.emo_unk_score = 1
+        self.beam_search.emo_tokens = tokenizer.encode(
+            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
+            allowed_special="all",
+        )
+        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+        self.beam_search.event_bg_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_ed_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
+
         encoder_out, encoder_out_lens = self.encode(
             speech[None, :, :].permute(0, 2, 1), speech_lengths
         )
@@ -843,6 +864,27 @@
         self.beam_search.sos = sos_int
         self.beam_search.eos = eos_int[0]
 
+        # Paramterts for rich decoding
+        self.beam_search.emo_unk = tokenizer.encode(
+            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
+        )[0]
+        self.beam_search.emo_unk_score = 1
+        self.beam_search.emo_tokens = tokenizer.encode(
+            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
+            allowed_special="all",
+        )
+        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+        self.beam_search.event_bg_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_ed_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
+
         encoder_out, encoder_out_lens = self.encode(
             speech[None, :, :].permute(0, 2, 1), speech_lengths
         )

--
Gitblit v1.9.1