From 33a9e08dc9b65abc3f3e18d14253f95c79e0f749 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 08 五月 2024 19:20:43 +0800
Subject: [PATCH] dynamic batch
---
funasr/models/sense_voice/model.py | 42 ++++++++++++++++++++++++++++++++++++++++++
1 files changed, 42 insertions(+), 0 deletions(-)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 0230638..56e61e7 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -514,6 +514,27 @@
self.beam_search.sos = sos_int
self.beam_search.eos = eos_int[0]
+ # Paramterts for rich decoding
+ self.beam_search.emo_unk = tokenizer.encode(
+ DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
+ )[0]
+ self.beam_search.emo_unk_score = 1
+ self.beam_search.emo_tokens = tokenizer.encode(
+ DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
+ allowed_special="all",
+ )
+ self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+ self.beam_search.event_bg_token = tokenizer.encode(
+ DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
+ allowed_special="all",
+ )
+ self.beam_search.event_ed_token = tokenizer.encode(
+ DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
+ allowed_special="all",
+ )
+ self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
+
encoder_out, encoder_out_lens = self.encode(
speech[None, :, :].permute(0, 2, 1), speech_lengths
)
@@ -843,6 +864,27 @@
self.beam_search.sos = sos_int
self.beam_search.eos = eos_int[0]
+ # Paramterts for rich decoding
+ self.beam_search.emo_unk = tokenizer.encode(
+ DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
+ )[0]
+ self.beam_search.emo_unk_score = 1
+ self.beam_search.emo_tokens = tokenizer.encode(
+ DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
+ allowed_special="all",
+ )
+ self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+ self.beam_search.event_bg_token = tokenizer.encode(
+ DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
+ allowed_special="all",
+ )
+ self.beam_search.event_ed_token = tokenizer.encode(
+ DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
+ allowed_special="all",
+ )
+ self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
+
encoder_out, encoder_out_lens = self.encode(
speech[None, :, :].permute(0, 2, 1), speech_lengths
)
--
Gitblit v1.9.1