From fb0da9f849a5d3bd473dcdbaf6197c6a5ff24a57 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 07 五月 2024 15:53:26 +0800
Subject: [PATCH] decoding key

---
 funasr/models/sense_voice/whisper_lib/decoding.py |  203 ++++++++++++++++++++------------------------------
 1 files changed, 82 insertions(+), 121 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 62be3bc..609d6a6 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -19,7 +19,11 @@
 
 @torch.no_grad()
 def detect_language(
-    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, initial_prompt = None, x = None,
+    model: "Whisper",
+    mel: Tensor,
+    tokenizer: Tokenizer = None,
+    initial_prompt=None,
+    x=None,
 ) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
@@ -34,16 +38,9 @@
         list of dictionaries containing the probability distribution over all languages.
     """
     if tokenizer is None:
-        tokenizer = get_tokenizer(
-            model.is_multilingual, num_languages=model.num_languages
-        )
-    if (
-        tokenizer.language is None
-        or tokenizer.language_token not in tokenizer.sot_sequence
-    ):
-        raise ValueError(
-            "This model doesn't have language tokens so it can't perform lang id"
-        )
+        tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
+    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
+        raise ValueError("This model doesn't have language tokens so it can't perform lang id")
 
     single = mel.ndim == 2
     if single:
@@ -59,17 +56,21 @@
     # FIX(funasr): sense vocie
     # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
     if x is None:
-        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
+        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(
+            mel.device
+        )  # [n_audio, 1]
 
     else:
         x = x.to(mel.device)
+    # FIX(funasr): sense vocie
+    # logits = model.logits(x[:, :-1], mel)[:, -1]
+    logits = model.logits(x[:, :], mel)[:, -1]
 
-    logits = model.logits(x[:,:-1], mel)[:, -1]
     # collect detected languages; suppress all non-language tokens
     mask = torch.ones(logits.shape[-1], dtype=torch.bool)
     mask[list(tokenizer.all_language_tokens)] = False
     mask[tokenizer.no_speech] = False
-    
+
     logits[:, mask] = -np.inf
     language_tokens = logits.argmax(dim=-1)
     language_token_probs = logits.softmax(dim=-1).cpu()
@@ -77,7 +78,10 @@
     language_probs = [
         {
             c: language_token_probs[i, j].item()
-            for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
+            for j, c in zip(
+                list(tokenizer.all_language_tokens) + [tokenizer.no_speech],
+                list(tokenizer.all_language_codes) + ["nospeech"],
+            )
         }
         for i in range(n_audio)
     ]
@@ -119,14 +123,16 @@
     suppress_blank: bool = True  # this will suppress blank outputs
 
     gain_event: bool = False  # this will suppress blank outputs
-    gain_tokens_bg: Optional[Union[str, List[int]]] = "<|Applause|><|Laughter|>"
-    gain_tokens_ed: Optional[Union[str, List[int]]] = "<|/Applause|><|/Laughter|>"
-    gain_tokens_score: List[float] = field(default_factory=lambda: [25.0, 5.0]) #[25, 5]
+    gain_tokens_bg: Optional[Union[str, List[int]]] = "<|Speech|><|BGM|><|Applause|><|Laughter|>"
+    gain_tokens_ed: Optional[Union[str, List[int]]] = (
+        "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"
+    )
+    gain_tokens_score: List[float] = field(default_factory=lambda: [1, 1, 25.0, 5.0])  # [25, 5]
 
     use_emo_threshold: bool = False  # this will suppress blank outputs
     emo_unk_token: Optional[Union[str, List[int]]] = "<|SPECIAL_TOKEN_1|>"
     emo_target_tokens: Optional[Union[str, List[int]]] = "<|HAPPY|><|SAD|><|ANGRY|>"
-    emo_target_threshold: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1]) #[25, 5]
+    emo_target_threshold: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1])  # [25, 5]
 
     # timestamp sampling options
     without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
@@ -203,9 +209,7 @@
 
 
 class SequenceRanker:
-    def rank(
-        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
-    ) -> List[int]:
+    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
         """
         Given a list of groups of samples and their cumulative log probabilities,
         return the indices of the samples in each group to select as the final result
@@ -243,9 +247,7 @@
     def reset(self):
         """Initialize any stateful variables for decoding a new sequence"""
 
-    def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
-    ) -> Tuple[Tensor, bool]:
+    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
         """Specify how to select the next token, based on the current trace and logits
 
         Parameters
@@ -300,9 +302,7 @@
         self.temperature = temperature
         self.eot = eot
 
-    def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
-    ) -> Tuple[Tensor, bool]:
+    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
         if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
@@ -339,16 +339,12 @@
         self.max_candidates: int = round(beam_size * self.patience)
         self.finished_sequences = None
 
-        assert (
-            self.max_candidates > 0
-        ), f"Invalid beam size ({beam_size}) or patience ({patience})"
+        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
 
     def reset(self):
         self.finished_sequences = None
 
-    def update(
-        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
-    ) -> Tuple[Tensor, bool]:
+    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
         if tokens.shape[0] % self.beam_size != 0:
             raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
 
@@ -392,9 +388,7 @@
 
         # add newly finished sequences to self.finished_sequences
         assert len(self.finished_sequences) == len(finished_sequences)
-        for previously_finished, newly_finished in zip(
-            self.finished_sequences, finished_sequences
-        ):
+        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
             for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                 if len(previously_finished) >= self.max_candidates:
                     break  # the candidate list is full
@@ -402,8 +396,7 @@
 
         # mark as completed if all audio has enough number of samples
         completed = all(
-            len(sequences) >= self.max_candidates
-            for sequences in self.finished_sequences
+            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
         )
         return tokens, completed
 
@@ -411,9 +404,7 @@
         # collect all finished sequences, including patience, and add unfinished ones if not enough
         sum_logprobs = sum_logprobs.cpu()
         for i, sequences in enumerate(self.finished_sequences):
-            if (
-                len(sequences) < self.beam_size
-            ):  # when not enough sequences are finished
+            if len(sequences) < self.beam_size:  # when not enough sequences are finished
                 for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                     sequence = preceding_tokens[i, j].tolist() + [self.eot]
                     sequences[tuple(sequence)] = sum_logprobs[i][j].item()
@@ -421,8 +412,7 @@
                         break
 
         tokens: List[List[Tensor]] = [
-            [torch.tensor(seq) for seq in sequences.keys()]
-            for sequences in self.finished_sequences
+            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
         ]
         sum_logprobs: List[List[float]] = [
             list(sequences.values()) for sequences in self.finished_sequences
@@ -463,8 +453,11 @@
     def apply(self, logits: Tensor, tokens: Tensor):
         logits[:, self.suppress_tokens] = -np.inf
 
+
 class GainEventToken(LogitFilter):
-    def __init__(self, bg_tokens: Sequence[int], ed_tokens:Sequence[int], gain_values: Sequence[float]):
+    def __init__(
+        self, bg_tokens: Sequence[int], ed_tokens: Sequence[int], gain_values: Sequence[float]
+    ):
         self.bg_tokens = list(bg_tokens)
         self.ed_tokens = list(ed_tokens)
         self.gain_value = [np.log(max(ga, 1e-9)) for ga in gain_values]
@@ -477,13 +470,16 @@
                 sum_bg = sum([1 if x == bg else 0 for x in tokens[i]])
                 sum_ed = sum([1 if x == ed else 0 for x in tokens[i]])
                 logits[i, bg] += ga
-                if sum_bg > sum_ed or tokens[i,-1] in [bg, ed]:
+                if sum_bg > sum_ed or tokens[i, -1] in [bg, ed]:
                     logits[i, bg] = -np.inf
                 if sum_bg <= sum_ed:
                     logits[i, ed] = -np.inf
 
+
 class ThresholdEmoToken(LogitFilter):
-    def __init__(self, unk_tokens: Sequence[int], emo_tokens:Sequence[int], th_values: Sequence[float]):
+    def __init__(
+        self, unk_tokens: Sequence[int], emo_tokens: Sequence[int], th_values: Sequence[float]
+    ):
         self.unk_token = list(unk_tokens)[0]
         self.emo_tokens = list(emo_tokens)
         self.th_values = list(th_values)
@@ -493,7 +489,7 @@
         for i in range(len(tokens)):
             for emo, th in zip(self.emo_tokens, self.th_values):
                 if logits[i].argmax() == emo and logits[i].softmax(dim=-1)[emo] < th:
-                    logits[i, self.unk_token] =  max(logits[i, emo], logits[i, self.unk_token])
+                    logits[i, self.unk_token] = max(logits[i, emo], logits[i, self.unk_token])
                     logits[i, emo] = -np.inf
 
             # for bg, ed, ga in zip(self.bg_tokens, self.ed_tokens, self.gain_value):
@@ -526,12 +522,8 @@
         for k in range(tokens.shape[0]):
             sampled_tokens = tokens[k, self.sample_begin :]
             seq = [t for t in sampled_tokens.tolist()]
-            last_was_timestamp = (
-                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
-            )
-            penultimate_was_timestamp = (
-                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
-            )
+            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
+            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
 
             if last_was_timestamp:
                 if penultimate_was_timestamp:  # has to be non-timestamp
@@ -539,9 +531,7 @@
                 else:  # cannot be normal text tokens
                     logits[k, : self.tokenizer.eot] = -np.inf
 
-            timestamps = sampled_tokens[
-                sampled_tokens.ge(self.tokenizer.timestamp_begin)
-            ]
+            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
             if timestamps.numel() > 0:
                 # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                 # also force each segment to have a nonzero length, to prevent infinite looping
@@ -557,17 +547,13 @@
 
             # apply the `max_initial_timestamp` option
             if self.max_initial_timestamp_index is not None:
-                last_allowed = (
-                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
-                )
+                last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                 logits[:, last_allowed + 1 :] = -np.inf
 
         # if sum of probability over timestamps is above any other token, sample timestamp
         logprobs = F.log_softmax(logits.float(), dim=-1)
         for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
-                dim=-1
-            )
+            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
             max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
             if timestamp_logprob > max_text_token_logprob:
                 logits[k, : self.tokenizer.timestamp_begin] = -np.inf
@@ -588,7 +574,7 @@
             num_languages=model.num_languages,
             language=language,
             task=options.task,
-            vocab_path=options.vocab_path
+            vocab_path=options.vocab_path,
         )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
@@ -626,30 +612,28 @@
         if self.options.suppress_tokens:
             self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
         if self.options.gain_event:
-            self.logit_filters.append(GainEventToken(
-                self.tokenizer.encode(self.options.gain_tokens_bg, allowed_special="all"),
-                self.tokenizer.encode(self.options.gain_tokens_ed, allowed_special="all"),
-                self.options.gain_tokens_score
+            self.logit_filters.append(
+                GainEventToken(
+                    self.tokenizer.encode(self.options.gain_tokens_bg, allowed_special="all"),
+                    self.tokenizer.encode(self.options.gain_tokens_ed, allowed_special="all"),
+                    self.options.gain_tokens_score,
                 )
             )
         if self.options.use_emo_threshold:
-            self.logit_filters.append(ThresholdEmoToken(
-                self.tokenizer.encode(self.options.emo_unk_token, allowed_special="all"),
-                self.tokenizer.encode(self.options.emo_target_tokens, allowed_special="all"),
-                self.options.emo_target_threshold
+            self.logit_filters.append(
+                ThresholdEmoToken(
+                    self.tokenizer.encode(self.options.emo_unk_token, allowed_special="all"),
+                    self.tokenizer.encode(self.options.emo_target_tokens, allowed_special="all"),
+                    self.options.emo_target_threshold,
                 )
             )
         if not options.without_timestamps:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
             if options.max_initial_timestamp:
-                max_initial_timestamp_index = round(
-                    self.options.max_initial_timestamp / precision
-                )
+                max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
             self.logit_filters.append(
-                ApplyTimestampRules(
-                    tokenizer, self.sample_begin, max_initial_timestamp_index
-                )
+                ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
             )
 
     def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
@@ -660,9 +644,7 @@
                 raise ValueError("best_of with greedy sampling (T=0) is not compatible")
         if options.patience is not None and options.beam_size is None:
             raise ValueError("patience requires beam_size to be given")
-        if options.length_penalty is not None and not (
-            0 <= options.length_penalty <= 1
-        ):
+        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
             raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
 
         return options
@@ -672,9 +654,7 @@
 
         if prefix := self.options.prefix:
             prefix_tokens = (
-                self.tokenizer.encode(" " + prefix.strip())
-                if isinstance(prefix, str)
-                else prefix
+                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
             )
             if self.sample_len is not None:
                 max_prefix_len = self.n_ctx // 2 - self.sample_len
@@ -683,16 +663,10 @@
 
         if prompt := self.options.prompt:
             prompt_tokens = (
-                self.tokenizer.encode(" " + prompt.strip())
-                if isinstance(prompt, str)
-                else prompt
+                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
             )
-            tokens = (
-                [self.tokenizer.sot_prev]
-                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
-                + tokens
-            )
-        #FIX(funasr): sense vocie
+            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
+        # FIX(funasr): sense vocie
         if initial_prompt := self.options.initial_prompt:
             if self.options.language is not None:
                 initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
@@ -700,7 +674,6 @@
             else:
                 tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
                 tokens += [0]
-
 
         return tuple(tokens)
 
@@ -746,12 +719,8 @@
         else:
             audio_features = self.model.encoder(mel)
 
-        if audio_features.dtype != (
-            torch.float16 if self.options.fp16 else torch.float32
-        ):
-            return TypeError(
-                f"audio_features has an incorrect dtype: {audio_features.dtype}"
-            )
+        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
+            return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
 
         return audio_features
 
@@ -766,15 +735,18 @@
             languages = [max(probs, key=probs.get) for probs in lang_probs]
             # FIX(funasr): sense vocie
             # if self.options.language is None:
-                # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
+            # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
             if self.options.language is None:
                 # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
                 languages = "".join([f"<|{language}|>" for language in languages])
 
                 n_audio = audio_features.shape[0]
-                lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
-                    audio_features.device)  # [n_audio, 1]
-                
+                lang_tokens = torch.tensor(
+                    [self.tokenizer.encode(languages, allowed_special="all")] * n_audio
+                ).to(
+                    audio_features.device
+                )  # [n_audio, 1]
+
                 tokens[:, -1:] = lang_tokens[:, :]
                 languages = [languages]
 
@@ -789,9 +761,7 @@
             for i in range(self.sample_len):
                 logits = self.inference.logits(tokens, audio_features)
 
-                if (
-                    i == 0 and self.tokenizer.no_speech is not None
-                ):  # save no_speech_probs
+                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                     no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
 
@@ -825,12 +795,8 @@
         languages, language_probs = self._detect_language(audio_features, tokens)
         if self.options.task == "lang_id":
             return [
-                DecodingResult(
-                    audio_features=features, language=language, language_probs=probs
-                )
-                for features, language, probs in zip(
-                    audio_features, languages, language_probs
-                )
+                DecodingResult(audio_features=features, language=language, language_probs=probs)
+                for features, language, probs in zip(audio_features, languages, language_probs)
             ]
 
         # repeat text tensors by the group size, for beam search or best-of-n sampling
@@ -850,8 +816,7 @@
         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
         tokens: List[List[Tensor]] = [
-            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
-            for s in tokens
+            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
         ]
 
         # select the top-ranked sample in each group
@@ -860,9 +825,7 @@
         texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
 
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
-        avg_logprobs: List[float] = [
-            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
-        ]
+        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
 
         fields = (
             texts,
@@ -886,9 +849,7 @@
                 temperature=self.options.temperature,
                 compression_ratio=compression_ratio(text),
             )
-            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
-                *fields
-            )
+            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
         ]
 
 

--
Gitblit v1.9.1