From c553a8db1712c2a5deeef5bbb68bd1fdf8d61ab7 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 13 六月 2024 17:38:01 +0800
Subject: [PATCH] decoding

---
 funasr/models/sense_voice/search.py |   60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 60 insertions(+), 0 deletions(-)

diff --git a/funasr/models/sense_voice/search.py b/funasr/models/sense_voice/search.py
index 98d02db..3a1a049 100644
--- a/funasr/models/sense_voice/search.py
+++ b/funasr/models/sense_voice/search.py
@@ -1,4 +1,5 @@
 from itertools import chain
+from dataclasses import field
 import logging
 from typing import Any
 from typing import Dict
@@ -8,6 +9,7 @@
 from typing import Union
 
 import torch
+import numpy as np
 
 from funasr.metrics.common import end_detect
 from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
@@ -42,6 +44,16 @@
         vocab_size: int,
         sos=None,
         eos=None,
+        # NOTE add rich decoding parameters
+        # [SPECIAL_TOKEN_1, HAPPY, SAD, ANGRY, NEUTRAL]
+        emo_unk: int = 58964,
+        emo_unk_score: float = 1.0,
+        emo_tokens: List[int] = field(default_factory=lambda: [58954, 58955, 58956, 58957]),
+        emo_scores: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1, 0.1]),
+        # [Speech, BGM, Laughter, Applause]
+        event_bg_token: List[int] = field(default_factory=lambda: [58946, 58948, 58950, 58952]),
+        event_ed_token: List[int] = field(default_factory=lambda: [58947, 58949, 58951, 58953]),
+        event_score_ga: List[float] = field(default_factory=lambda: [1, 1, 5, 25]),
         token_list: List[str] = None,
         pre_beam_ratio: float = 1.5,
         pre_beam_score_key: str = None,
@@ -110,6 +122,14 @@
             and len(self.part_scorers) > 0
         )
 
+        self.emo_unk = emo_unk
+        self.emo_unk_score = emo_unk_score
+        self.emo_tokens = emo_tokens
+        self.emo_scores = emo_scores
+        self.event_bg_token = event_bg_token
+        self.event_ed_token = event_ed_token
+        self.event_score_ga = event_score_ga
+
     def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
         """Get an initial hypothesis data.
 
@@ -170,8 +190,46 @@
         """
         scores = dict()
         states = dict()
+
+        def get_score(yseq, sp1, sp2):
+            score = [0, 0]
+            last_token = yseq[-1]
+            last_token2 = yseq[-2] if len(yseq) > 1 else yseq[-1]
+            sum_sp1 = sum([1 if x == sp1 else 0 for x in yseq])
+            sum_sp2 = sum([1 if x == sp2 else 0 for x in yseq])
+            if sum_sp1 > sum_sp2 or last_token in [sp1, sp2]:
+                score[0] = -np.inf
+            if sum_sp2 >= sum_sp1:
+                score[1] = -np.inf
+            return score
+
+        def struct_score(yseq, score):
+            import math
+
+            last_token = yseq[-1]
+            if last_token in self.emo_tokens + [self.emo_unk]:
+                # prevent output event after emotation token
+                score[self.event_bg_token] = -np.inf
+
+            for eve_bg, eve_ed, eve_ga in zip(
+                self.event_bg_token, self.event_ed_token, self.event_score_ga
+            ):
+                score_offset = get_score(yseq, eve_bg, eve_ed)
+                score[eve_bg] += score_offset[0]
+                score[eve_ed] += score_offset[1]
+                score[eve_bg] += math.log(eve_ga)
+
+            score[self.emo_unk] += math.log(self.emo_unk_score)
+            for emo, emo_th in zip(self.emo_tokens, self.emo_scores):
+                if score.argmax() == emo and score[emo] < math.log(emo_th):
+                    score[self.emo_unk] = max(score[emo], score[self.emo_unk])
+                    score[emo] = -np.inf
+            return score
+
         for k, d in self.full_scorers.items():
             scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
+            scores[k] = struct_score(hyp.yseq, scores[k])
+
         return scores, states
 
     def score_partial(
@@ -370,6 +428,8 @@
             # post process of one iteration
             running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
             # end detection
+            # if len(ended_hyps) > 0:
+            #     print(f"ended_hyps: {ended_hyps}")
             if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                 logging.info(f"end detected at {i}")
                 break

--
Gitblit v1.9.1