From 95bed2337e8065d3331109d6c2d00349ad82fd77 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 08 五月 2024 19:14:21 +0800
Subject: [PATCH] Merge branch 'dev_gzf_exp' of github.com:alibaba-damo-academy/FunASR into dev_gzf_exp merge
---
funasr/models/sense_voice/model.py | 28 ++++++++++++++++++++++++++++
1 files changed, 28 insertions(+), 0 deletions(-)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 0230638..00bc85b 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -514,6 +514,20 @@
self.beam_search.sos = sos_int
self.beam_search.eos = eos_int[0]
+ # Parameters for rich decoding
+ self.beam_search.emo_unk = tokenizer.encode(
+ DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all")[0]
+ self.beam_search.emo_unk_score = 1
+ self.beam_search.emo_tokens = tokenizer.encode(
+ DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"), allowed_special="all")
+ self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+ self.beam_search.event_bg_token = tokenizer.encode(
+ DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"), allowed_special="all")
+ self.beam_search.event_ed_token = tokenizer.encode(
+ DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"), allowed_special="all")
+ self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
+
encoder_out, encoder_out_lens = self.encode(
speech[None, :, :].permute(0, 2, 1), speech_lengths
)
@@ -843,6 +857,20 @@
self.beam_search.sos = sos_int
self.beam_search.eos = eos_int[0]
+ # Parameters for rich decoding
+ self.beam_search.emo_unk = tokenizer.encode(
+ DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all")[0]
+ self.beam_search.emo_unk_score = 1
+ self.beam_search.emo_tokens = tokenizer.encode(
+ DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"), allowed_special="all")
+ self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+ self.beam_search.event_bg_token = tokenizer.encode(
+ DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"), allowed_special="all")
+ self.beam_search.event_ed_token = tokenizer.encode(
+ DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"), allowed_special="all")
+ self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
+
encoder_out, encoder_out_lens = self.encode(
speech[None, :, :].permute(0, 2, 1), speech_lengths
)
--
Gitblit v1.9.1