From fce4e1d1b48f23cd8332e60afce3df8d6209a6a7 Mon Sep 17 00:00:00 2001
From: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com>
Date: 星期四, 11 四月 2024 14:59:22 +0800
Subject: [PATCH] SenseVoice对富文本解码的参数 (#1608)
---
funasr/models/sense_voice/whisper_lib/decoding.py | 66 +++++++++++++++++++++++++++++++++
1 files changed, 66 insertions(+), 0 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 2239b64..62be3bc 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -118,6 +118,16 @@
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
suppress_blank: bool = True # this will suppress blank outputs
+ gain_event: bool = False # this will suppress blank outputs
+ gain_tokens_bg: Optional[Union[str, List[int]]] = "<|Applause|><|Laughter|>"
+ gain_tokens_ed: Optional[Union[str, List[int]]] = "<|/Applause|><|/Laughter|>"
+ gain_tokens_score: List[float] = field(default_factory=lambda: [25.0, 5.0]) #[25, 5]
+
+ use_emo_threshold: bool = False # this will suppress blank outputs
+ emo_unk_token: Optional[Union[str, List[int]]] = "<|SPECIAL_TOKEN_1|>"
+ emo_target_tokens: Optional[Union[str, List[int]]] = "<|HAPPY|><|SAD|><|ANGRY|>"
+ emo_target_threshold: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1]) #[25, 5]
+
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0
@@ -453,6 +463,48 @@
def apply(self, logits: Tensor, tokens: Tensor):
logits[:, self.suppress_tokens] = -np.inf
+class GainEventToken(LogitFilter):
+ def __init__(self, bg_tokens: Sequence[int], ed_tokens:Sequence[int], gain_values: Sequence[float]):
+ self.bg_tokens = list(bg_tokens)
+ self.ed_tokens = list(ed_tokens)
+ self.gain_value = [np.log(max(ga, 1e-9)) for ga in gain_values]
+ assert len(self.ed_tokens) == len(self.gain_value)
+ assert len(self.bg_tokens) == len(self.gain_value)
+
+ def apply(self, logits: Tensor, tokens: Tensor):
+ for i in range(len(tokens)):
+ for bg, ed, ga in zip(self.bg_tokens, self.ed_tokens, self.gain_value):
+ sum_bg = sum([1 if x == bg else 0 for x in tokens[i]])
+ sum_ed = sum([1 if x == ed else 0 for x in tokens[i]])
+ logits[i, bg] += ga
+ if sum_bg > sum_ed or tokens[i,-1] in [bg, ed]:
+ logits[i, bg] = -np.inf
+ if sum_bg <= sum_ed:
+ logits[i, ed] = -np.inf
+
+class ThresholdEmoToken(LogitFilter):
+ def __init__(self, unk_tokens: Sequence[int], emo_tokens:Sequence[int], th_values: Sequence[float]):
+ self.unk_token = list(unk_tokens)[0]
+ self.emo_tokens = list(emo_tokens)
+ self.th_values = list(th_values)
+ assert len(self.emo_tokens) == len(self.th_values)
+
+ def apply(self, logits: Tensor, tokens: Tensor):
+ for i in range(len(tokens)):
+ for emo, th in zip(self.emo_tokens, self.th_values):
+ if logits[i].argmax() == emo and logits[i].softmax(dim=-1)[emo] < th:
+ logits[i, self.unk_token] = max(logits[i, emo], logits[i, self.unk_token])
+ logits[i, emo] = -np.inf
+
+ # for bg, ed, ga in zip(self.bg_tokens, self.ed_tokens, self.gain_value):
+ # sum_bg = sum([1 if x == bg else 0 for x in tokens[i]])
+ # sum_ed = sum([1 if x == ed else 0 for x in tokens[i]])
+ # logits[i, bg] += ga
+ # if sum_bg > sum_ed or tokens[i,-1] in [bg, ed]:
+ # logits[i, bg] = -np.inf
+ # if sum_bg <= sum_ed:
+ # logits[i, ed] = -np.inf
+
class ApplyTimestampRules(LogitFilter):
def __init__(
@@ -573,6 +625,20 @@
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
if self.options.suppress_tokens:
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
+ if self.options.gain_event:
+ self.logit_filters.append(GainEventToken(
+ self.tokenizer.encode(self.options.gain_tokens_bg, allowed_special="all"),
+ self.tokenizer.encode(self.options.gain_tokens_ed, allowed_special="all"),
+ self.options.gain_tokens_score
+ )
+ )
+ if self.options.use_emo_threshold:
+ self.logit_filters.append(ThresholdEmoToken(
+ self.tokenizer.encode(self.options.emo_unk_token, allowed_special="all"),
+ self.tokenizer.encode(self.options.emo_target_tokens, allowed_special="all"),
+ self.options.emo_target_threshold
+ )
+ )
if not options.without_timestamps:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
--
Gitblit v1.9.1