From fce4e1d1b48f23cd8332e60afce3df8d6209a6a7 Mon Sep 17 00:00:00 2001
From: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com>
Date: 星期四, 11 四月 2024 14:59:22 +0800
Subject: [PATCH] SenseVoice对富文本解码的参数 (#1608)

---
 funasr/models/sense_voice/whisper_lib/decoding.py |   66 +++++++++++++++++++++++++++++++++
 1 files changed, 66 insertions(+), 0 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 2239b64..62be3bc 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -118,6 +118,16 @@
     suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
     suppress_blank: bool = True  # this will suppress blank outputs
 
+    gain_event: bool = False  # this will suppress blank outputs
+    gain_tokens_bg: Optional[Union[str, List[int]]] = "<|Applause|><|Laughter|>"
+    gain_tokens_ed: Optional[Union[str, List[int]]] = "<|/Applause|><|/Laughter|>"
+    gain_tokens_score: List[float] = field(default_factory=lambda: [25.0, 5.0]) #[25, 5]
+
+    use_emo_threshold: bool = False  # this will suppress blank outputs
+    emo_unk_token: Optional[Union[str, List[int]]] = "<|SPECIAL_TOKEN_1|>"
+    emo_target_tokens: Optional[Union[str, List[int]]] = "<|HAPPY|><|SAD|><|ANGRY|>"
+    emo_target_threshold: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1]) #[25, 5]
+
     # timestamp sampling options
     without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
     max_initial_timestamp: Optional[float] = 1.0
@@ -453,6 +463,48 @@
     def apply(self, logits: Tensor, tokens: Tensor):
         logits[:, self.suppress_tokens] = -np.inf
 
+class GainEventToken(LogitFilter):
+    def __init__(self, bg_tokens: Sequence[int], ed_tokens:Sequence[int], gain_values: Sequence[float]):
+        self.bg_tokens = list(bg_tokens)
+        self.ed_tokens = list(ed_tokens)
+        self.gain_value = [np.log(max(ga, 1e-9)) for ga in gain_values]
+        assert len(self.ed_tokens) == len(self.gain_value)
+        assert len(self.bg_tokens) == len(self.gain_value)
+
+    def apply(self, logits: Tensor, tokens: Tensor):
+        for i in range(len(tokens)):
+            for bg, ed, ga in zip(self.bg_tokens, self.ed_tokens, self.gain_value):
+                sum_bg = sum([1 if x == bg else 0 for x in tokens[i]])
+                sum_ed = sum([1 if x == ed else 0 for x in tokens[i]])
+                logits[i, bg] += ga
+                if sum_bg > sum_ed or tokens[i,-1] in [bg, ed]:
+                    logits[i, bg] = -np.inf
+                if sum_bg <= sum_ed:
+                    logits[i, ed] = -np.inf
+
+class ThresholdEmoToken(LogitFilter):
+    def __init__(self, unk_tokens: Sequence[int], emo_tokens:Sequence[int], th_values: Sequence[float]):
+        self.unk_token = list(unk_tokens)[0]
+        self.emo_tokens = list(emo_tokens)
+        self.th_values = list(th_values)
+        assert len(self.emo_tokens) == len(self.th_values)
+
+    def apply(self, logits: Tensor, tokens: Tensor):
+        for i in range(len(tokens)):
+            for emo, th in zip(self.emo_tokens, self.th_values):
+                if logits[i].argmax() == emo and logits[i].softmax(dim=-1)[emo] < th:
+                    logits[i, self.unk_token] =  max(logits[i, emo], logits[i, self.unk_token])
+                    logits[i, emo] = -np.inf
+
+            # for bg, ed, ga in zip(self.bg_tokens, self.ed_tokens, self.gain_value):
+            #     sum_bg = sum([1 if x == bg else 0 for x in tokens[i]])
+            #     sum_ed = sum([1 if x == ed else 0 for x in tokens[i]])
+            #     logits[i, bg] += ga
+            #     if sum_bg > sum_ed or tokens[i,-1] in [bg, ed]:
+            #         logits[i, bg] = -np.inf
+            #     if sum_bg <= sum_ed:
+            #         logits[i, ed] = -np.inf
+
 
 class ApplyTimestampRules(LogitFilter):
     def __init__(
@@ -573,6 +625,20 @@
             self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
         if self.options.suppress_tokens:
             self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
+        if self.options.gain_event:
+            self.logit_filters.append(GainEventToken(
+                self.tokenizer.encode(self.options.gain_tokens_bg, allowed_special="all"),
+                self.tokenizer.encode(self.options.gain_tokens_ed, allowed_special="all"),
+                self.options.gain_tokens_score
+                )
+            )
+        if self.options.use_emo_threshold:
+            self.logit_filters.append(ThresholdEmoToken(
+                self.tokenizer.encode(self.options.emo_unk_token, allowed_special="all"),
+                self.tokenizer.encode(self.options.emo_target_tokens, allowed_special="all"),
+                self.options.emo_target_threshold
+                )
+            )
         if not options.without_timestamps:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None

--
Gitblit v1.9.1