From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: Fri, 26 Apr 2024 14:59:30 +0800
Subject: [PATCH] FunASR Java HTTP client

---
 funasr/models/sense_voice/whisper_lib/decoding.py |  116 +++++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 109 insertions(+), 7 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 49485d0..203efe8 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -10,6 +10,8 @@
 from .audio import CHUNK_LENGTH
 from .tokenizer import Tokenizer, get_tokenizer
 from .utils import compression_ratio
+from funasr.models.transformer.utils.nets_utils import to_device
+
 
 if TYPE_CHECKING:
     from .model import Whisper
@@ -17,7 +19,7 @@
 
 @torch.no_grad()
 def detect_language(
-    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
+    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, initial_prompt = None, x = None,
 ) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
@@ -48,24 +50,36 @@
         mel = mel.unsqueeze(0)
 
     # skip encoder forward pass if already-encoded audio features were given
-    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
+    # FIX(funasr): sense voice, the encoded length can vary, so only check the feature dimension
+    if mel.shape[-1] != model.dims.n_audio_state:
         mel = model.encoder(mel)
 
     # forward pass using a single token, startoftranscript
     n_audio = mel.shape[0]
-    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
-    logits = model.logits(x, mel)[:, 0]
+    # FIX(funasr): sense voice, score the initial prompt instead of a single sot token
+    if x is None:
+        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, n_prompt]
 
+    else:
+        x = x.to(mel.device)
+
+    # predict the dropped last prompt token, i.e. the language slot
+    logits = model.logits(x[:, :-1], mel)[:, -1]
+
     # collect detected languages; suppress all non-language tokens
     mask = torch.ones(logits.shape[-1], dtype=torch.bool)
     mask[list(tokenizer.all_language_tokens)] = False
+    mask[tokenizer.no_speech] = False  # keep the nospeech token as a candidate
+
     logits[:, mask] = -np.inf
     language_tokens = logits.argmax(dim=-1)
     language_token_probs = logits.softmax(dim=-1).cpu()
+
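+    # the returned probabilities also report "nospeech" alongside the language codes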
     language_probs = [
         {
             c: language_token_probs[i, j].item()
-            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
+            for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
         }
         for i in range(n_audio)
     ]
@@ -106,12 +120,26 @@
     suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
     suppress_blank: bool = True  # this will suppress blank outputs
 
+    gain_event: bool = False  # apply a gain to audio event tokens and keep their begin/end tags paired
+    gain_tokens_bg: Optional[Union[str, List[int]]] = "<|Speech|><|BGM|><|Applause|><|Laughter|>"
+    gain_tokens_ed: Optional[Union[str, List[int]]] = "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"
+    gain_tokens_score: List[float] = field(default_factory=lambda: [1, 1, 25.0, 5.0])  # multiplicative gain per event type, applied in log space
+
+    use_emo_threshold: bool = False  # replace low-confidence emotion tokens with emo_unk_token
+    emo_unk_token: Optional[Union[str, List[int]]] = "<|SPECIAL_TOKEN_1|>"
+    emo_target_tokens: Optional[Union[str, List[int]]] = "<|HAPPY|><|SAD|><|ANGRY|>"
+    emo_target_threshold: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1])  # minimum winning probability per emotion token
+
     # timestamp sampling options
     without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
     max_initial_timestamp: Optional[float] = 1.0
 
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
+
+    # FIX(funasr): sense voice
+    initial_prompt: Optional[str] = None
+    vocab_path: Optional[str] = None
 
 
 @dataclass(frozen=True)
@@ -437,6 +465,49 @@
     def apply(self, logits: Tensor, tokens: Tensor):
         logits[:, self.suppress_tokens] = -np.inf
 
+
+class GainEventToken(LogitFilter):
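+    """Boost begin-of-event tokens by a log-domain gain and keep begin/end event tags paired."""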
+    def __init__(self, bg_tokens: Sequence[int], ed_tokens: Sequence[int], gain_values: Sequence[float]):
+        self.bg_tokens = list(bg_tokens)
+        self.ed_tokens = list(ed_tokens)
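+        # logits are log-probabilities, so a multiplicative gain becomes an additive offset in log space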
+        self.gain_value = [np.log(max(ga, 1e-9)) for ga in gain_values]
+        assert len(self.ed_tokens) == len(self.gain_value)
+        assert len(self.bg_tokens) == len(self.gain_value)
+
+    def apply(self, logits: Tensor, tokens: Tensor):
+        for i in range(len(tokens)):
+            for bg, ed, ga in zip(self.bg_tokens, self.ed_tokens, self.gain_value):
+                # count how many times this event has been opened and closed so far
+                sum_bg = int((tokens[i] == bg).sum())
+                sum_ed = int((tokens[i] == ed).sum())
+                logits[i, bg] += ga
+                # block a new begin tag while the event is open or right after an event tag
+                if sum_bg > sum_ed or tokens[i, -1] in [bg, ed]:
+                    logits[i, bg] = -np.inf
+                # an end tag is only allowed while the event is open
+                if sum_bg <= sum_ed:
+                    logits[i, ed] = -np.inf
+
+
+class ThresholdEmoToken(LogitFilter):
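+    """Fall back to the unk token when the winning emotion token is below its probability threshold."""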
+    def __init__(self, unk_tokens: Sequence[int], emo_tokens: Sequence[int], th_values: Sequence[float]):
+        self.unk_token = list(unk_tokens)[0]
+        self.emo_tokens = list(emo_tokens)
+        self.th_values = list(th_values)
+        assert len(self.emo_tokens) == len(self.th_values)
+
+    def apply(self, logits: Tensor, tokens: Tensor):
+        for i in range(len(tokens)):
+            for emo, th in zip(self.emo_tokens, self.th_values):
+                # demote an emotion token that wins with probability below its threshold
+                if logits[i].argmax() == emo and logits[i].softmax(dim=-1)[emo] < th:
+                    logits[i, self.unk_token] = max(logits[i, emo], logits[i, self.unk_token])
+                    logits[i, emo] = -np.inf
+
+
 
 class ApplyTimestampRules(LogitFilter):
     def __init__(
@@ -520,6 +591,7 @@
             num_languages=model.num_languages,
             language=language,
             task=options.task,
+            vocab_path=options.vocab_path,
         )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
@@ -556,6 +628,18 @@
             self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
         if self.options.suppress_tokens:
             self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
+        if self.options.gain_event:
+            self.logit_filters.append(GainEventToken(
+                self.tokenizer.encode(self.options.gain_tokens_bg, allowed_special="all"),
+                self.tokenizer.encode(self.options.gain_tokens_ed, allowed_special="all"),
+                self.options.gain_tokens_score,
+            ))
+        if self.options.use_emo_threshold:
+            self.logit_filters.append(ThresholdEmoToken(
+                self.tokenizer.encode(self.options.emo_unk_token, allowed_special="all"),
+                self.tokenizer.encode(self.options.emo_target_tokens, allowed_special="all"),
+                self.options.emo_target_threshold,
+            ))
         if not options.without_timestamps:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
@@ -609,6 +693,14 @@
                 + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                 + tokens
             )
+        # FIX(funasr): sense voice
+        if initial_prompt := self.options.initial_prompt:
+            if self.options.language is not None:
+                initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
+                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
+            else:
+                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
+                tokens += [0]  # placeholder slot, overwritten with the detected language token
 
         return tuple(tokens)
 
@@ -669,11 +761,21 @@
 
         if self.options.language is None or self.options.task == "lang_id":
             lang_tokens, lang_probs = self.model.detect_language(
-                audio_features, self.tokenizer
+                audio_features, self.tokenizer, x=tokens
             )
             languages = [max(probs, key=probs.get) for probs in lang_probs]
+            # FIX(funasr): sense voice, insert detected language tokens into the prompt
             if self.options.language is None:
-                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
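+                # join the detected codes into language tag tokens, e.g. "<|zh|>"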
+                languages = "".join([f"<|{language}|>" for language in languages])
+
+                n_audio = audio_features.shape[0]
+                lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
+                    audio_features.device)  # [n_audio, n_lang]
+
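+                # overwrite the trailing placeholder token with the detected language token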
+                tokens[:, -1:] = lang_tokens[:, :]
+                languages = [languages]
 
         return languages, lang_probs
 

--
Gitblit v1.9.1