| | |
| | | |
| | | @torch.no_grad() |
| | | def detect_language( |
| | | model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, initial_prompt = None, x = None, |
| | | model: "Whisper", |
| | | mel: Tensor, |
| | | tokenizer: Tokenizer = None, |
| | | initial_prompt=None, |
| | | x=None, |
| | | ) -> Tuple[Tensor, List[dict]]: |
| | | """ |
| | | Detect the spoken language in the audio, and return them as list of strings, along with the ids |
| | |
| | | list of dictionaries containing the probability distribution over all languages. |
| | | """ |
| | | if tokenizer is None: |
| | | tokenizer = get_tokenizer( |
| | | model.is_multilingual, num_languages=model.num_languages |
| | | ) |
| | | if ( |
| | | tokenizer.language is None |
| | | or tokenizer.language_token not in tokenizer.sot_sequence |
| | | ): |
| | | raise ValueError( |
| | | "This model doesn't have language tokens so it can't perform lang id" |
| | | ) |
| | | tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages) |
| | | if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: |
| | | raise ValueError("This model doesn't have language tokens so it can't perform lang id") |
| | | |
| | | single = mel.ndim == 2 |
| | | if single: |
| | |
| | | # FIX(funasr): sense vocie |
| | | # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] |
| | | if x is None: |
| | | x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1] |
| | | x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to( |
| | | mel.device |
| | | ) # [n_audio, 1] |
| | | |
| | | else: |
| | | x = x.to(mel.device) |
| | | # FIX(funasr): sense vocie |
| | | # logits = model.logits(x[:, :-1], mel)[:, -1] |
| | | logits = model.logits(x[:, :], mel)[:, -1] |
| | | |
| | | logits = model.logits(x[:,:-1], mel)[:, -1] |
| | | # collect detected languages; suppress all non-language tokens |
| | | mask = torch.ones(logits.shape[-1], dtype=torch.bool) |
| | | mask[list(tokenizer.all_language_tokens)] = False |
| | | mask[tokenizer.no_speech] = False |
| | | |
| | | |
| | | logits[:, mask] = -np.inf |
| | | language_tokens = logits.argmax(dim=-1) |
| | | language_token_probs = logits.softmax(dim=-1).cpu() |
| | |
| | | language_probs = [ |
| | | { |
| | | c: language_token_probs[i, j].item() |
| | | for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"]) |
| | | for j, c in zip( |
| | | list(tokenizer.all_language_tokens) + [tokenizer.no_speech], |
| | | list(tokenizer.all_language_codes) + ["nospeech"], |
| | | ) |
| | | } |
| | | for i in range(n_audio) |
| | | ] |
| | |
| | | suppress_blank: bool = True # this will suppress blank outputs |
| | | |
| | | gain_event: bool = False # this will suppress blank outputs |
| | | gain_tokens_bg: Optional[Union[str, List[int]]] = "<|Applause|><|Laughter|>" |
| | | gain_tokens_ed: Optional[Union[str, List[int]]] = "<|/Applause|><|/Laughter|>" |
| | | gain_tokens_score: List[float] = field(default_factory=lambda: [25.0, 5.0]) #[25, 5] |
| | | gain_tokens_bg: Optional[Union[str, List[int]]] = "<|Speech|><|BGM|><|Applause|><|Laughter|>" |
| | | gain_tokens_ed: Optional[Union[str, List[int]]] = ( |
| | | "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>" |
| | | ) |
| | | gain_tokens_score: List[float] = field(default_factory=lambda: [1, 1, 25.0, 5.0]) # [25, 5] |
| | | |
| | | use_emo_threshold: bool = False # this will suppress blank outputs |
| | | emo_unk_token: Optional[Union[str, List[int]]] = "<|SPECIAL_TOKEN_1|>" |
| | | emo_target_tokens: Optional[Union[str, List[int]]] = "<|HAPPY|><|SAD|><|ANGRY|>" |
| | | emo_target_threshold: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1]) #[25, 5] |
| | | emo_target_threshold: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1]) # [25, 5] |
| | | |
| | | # timestamp sampling options |
| | | without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only |
| | |
| | | |
| | | |
| | | class SequenceRanker: |
| | | def rank( |
| | | self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]] |
| | | ) -> List[int]: |
| | | def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]: |
| | | """ |
| | | Given a list of groups of samples and their cumulative log probabilities, |
| | | return the indices of the samples in each group to select as the final result |
| | |
| | | def reset(self): |
| | | """Initialize any stateful variables for decoding a new sequence""" |
| | | |
| | | def update( |
| | | self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor |
| | | ) -> Tuple[Tensor, bool]: |
| | | def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: |
| | | """Specify how to select the next token, based on the current trace and logits |
| | | |
| | | Parameters |
| | |
| | | self.temperature = temperature |
| | | self.eot = eot |
| | | |
| | | def update( |
| | | self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor |
| | | ) -> Tuple[Tensor, bool]: |
| | | def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: |
| | | if self.temperature == 0: |
| | | next_tokens = logits.argmax(dim=-1) |
| | | else: |
| | |
| | | self.max_candidates: int = round(beam_size * self.patience) |
| | | self.finished_sequences = None |
| | | |
| | | assert ( |
| | | self.max_candidates > 0 |
| | | ), f"Invalid beam size ({beam_size}) or patience ({patience})" |
| | | assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})" |
| | | |
| | | def reset(self): |
| | | self.finished_sequences = None |
| | | |
| | | def update( |
| | | self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor |
| | | ) -> Tuple[Tensor, bool]: |
| | | def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: |
| | | if tokens.shape[0] % self.beam_size != 0: |
| | | raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") |
| | | |
| | |
| | | |
| | | # add newly finished sequences to self.finished_sequences |
| | | assert len(self.finished_sequences) == len(finished_sequences) |
| | | for previously_finished, newly_finished in zip( |
| | | self.finished_sequences, finished_sequences |
| | | ): |
| | | for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences): |
| | | for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): |
| | | if len(previously_finished) >= self.max_candidates: |
| | | break # the candidate list is full |
| | |
| | | |
| | | # mark as completed if all audio has enough number of samples |
| | | completed = all( |
| | | len(sequences) >= self.max_candidates |
| | | for sequences in self.finished_sequences |
| | | len(sequences) >= self.max_candidates for sequences in self.finished_sequences |
| | | ) |
| | | return tokens, completed |
| | | |
| | |
| | | # collect all finished sequences, including patience, and add unfinished ones if not enough |
| | | sum_logprobs = sum_logprobs.cpu() |
| | | for i, sequences in enumerate(self.finished_sequences): |
| | | if ( |
| | | len(sequences) < self.beam_size |
| | | ): # when not enough sequences are finished |
| | | if len(sequences) < self.beam_size: # when not enough sequences are finished |
| | | for j in list(np.argsort(sum_logprobs[i]))[::-1]: |
| | | sequence = preceding_tokens[i, j].tolist() + [self.eot] |
| | | sequences[tuple(sequence)] = sum_logprobs[i][j].item() |
| | |
| | | break |
| | | |
| | | tokens: List[List[Tensor]] = [ |
| | | [torch.tensor(seq) for seq in sequences.keys()] |
| | | for sequences in self.finished_sequences |
| | | [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences |
| | | ] |
| | | sum_logprobs: List[List[float]] = [ |
| | | list(sequences.values()) for sequences in self.finished_sequences |
| | |
| | | def apply(self, logits: Tensor, tokens: Tensor): |
| | | logits[:, self.suppress_tokens] = -np.inf |
| | | |
| | | |
| | | class GainEventToken(LogitFilter): |
| | | def __init__(self, bg_tokens: Sequence[int], ed_tokens:Sequence[int], gain_values: Sequence[float]): |
| | | def __init__( |
| | | self, bg_tokens: Sequence[int], ed_tokens: Sequence[int], gain_values: Sequence[float] |
| | | ): |
| | | self.bg_tokens = list(bg_tokens) |
| | | self.ed_tokens = list(ed_tokens) |
| | | self.gain_value = [np.log(max(ga, 1e-9)) for ga in gain_values] |
| | |
| | | sum_bg = sum([1 if x == bg else 0 for x in tokens[i]]) |
| | | sum_ed = sum([1 if x == ed else 0 for x in tokens[i]]) |
| | | logits[i, bg] += ga |
| | | if sum_bg > sum_ed or tokens[i,-1] in [bg, ed]: |
| | | if sum_bg > sum_ed or tokens[i, -1] in [bg, ed]: |
| | | logits[i, bg] = -np.inf |
| | | if sum_bg <= sum_ed: |
| | | logits[i, ed] = -np.inf |
| | | |
| | | |
| | | class ThresholdEmoToken(LogitFilter): |
| | | def __init__(self, unk_tokens: Sequence[int], emo_tokens:Sequence[int], th_values: Sequence[float]): |
| | | def __init__( |
| | | self, unk_tokens: Sequence[int], emo_tokens: Sequence[int], th_values: Sequence[float] |
| | | ): |
| | | self.unk_token = list(unk_tokens)[0] |
| | | self.emo_tokens = list(emo_tokens) |
| | | self.th_values = list(th_values) |
| | |
| | | for i in range(len(tokens)): |
| | | for emo, th in zip(self.emo_tokens, self.th_values): |
| | | if logits[i].argmax() == emo and logits[i].softmax(dim=-1)[emo] < th: |
| | | logits[i, self.unk_token] = max(logits[i, emo], logits[i, self.unk_token]) |
| | | logits[i, self.unk_token] = max(logits[i, emo], logits[i, self.unk_token]) |
| | | logits[i, emo] = -np.inf |
| | | |
| | | # for bg, ed, ga in zip(self.bg_tokens, self.ed_tokens, self.gain_value): |
| | |
| | | for k in range(tokens.shape[0]): |
| | | sampled_tokens = tokens[k, self.sample_begin :] |
| | | seq = [t for t in sampled_tokens.tolist()] |
| | | last_was_timestamp = ( |
| | | len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin |
| | | ) |
| | | penultimate_was_timestamp = ( |
| | | len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin |
| | | ) |
| | | last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin |
| | | penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin |
| | | |
| | | if last_was_timestamp: |
| | | if penultimate_was_timestamp: # has to be non-timestamp |
| | |
| | | else: # cannot be normal text tokens |
| | | logits[k, : self.tokenizer.eot] = -np.inf |
| | | |
| | | timestamps = sampled_tokens[ |
| | | sampled_tokens.ge(self.tokenizer.timestamp_begin) |
| | | ] |
| | | timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)] |
| | | if timestamps.numel() > 0: |
| | | # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last |
| | | # also force each segment to have a nonzero length, to prevent infinite looping |
| | |
| | | |
| | | # apply the `max_initial_timestamp` option |
| | | if self.max_initial_timestamp_index is not None: |
| | | last_allowed = ( |
| | | self.tokenizer.timestamp_begin + self.max_initial_timestamp_index |
| | | ) |
| | | last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index |
| | | logits[:, last_allowed + 1 :] = -np.inf |
| | | |
| | | # if sum of probability over timestamps is above any other token, sample timestamp |
| | | logprobs = F.log_softmax(logits.float(), dim=-1) |
| | | for k in range(tokens.shape[0]): |
| | | timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp( |
| | | dim=-1 |
| | | ) |
| | | timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) |
| | | max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() |
| | | if timestamp_logprob > max_text_token_logprob: |
| | | logits[k, : self.tokenizer.timestamp_begin] = -np.inf |
| | |
| | | num_languages=model.num_languages, |
| | | language=language, |
| | | task=options.task, |
| | | vocab_path=options.vocab_path |
| | | vocab_path=options.vocab_path, |
| | | ) |
| | | self.tokenizer: Tokenizer = tokenizer |
| | | self.options: DecodingOptions = self._verify_options(options) |
| | |
| | | if self.options.suppress_tokens: |
| | | self.logit_filters.append(SuppressTokens(self._get_suppress_tokens())) |
| | | if self.options.gain_event: |
| | | self.logit_filters.append(GainEventToken( |
| | | self.tokenizer.encode(self.options.gain_tokens_bg, allowed_special="all"), |
| | | self.tokenizer.encode(self.options.gain_tokens_ed, allowed_special="all"), |
| | | self.options.gain_tokens_score |
| | | self.logit_filters.append( |
| | | GainEventToken( |
| | | self.tokenizer.encode(self.options.gain_tokens_bg, allowed_special="all"), |
| | | self.tokenizer.encode(self.options.gain_tokens_ed, allowed_special="all"), |
| | | self.options.gain_tokens_score, |
| | | ) |
| | | ) |
| | | if self.options.use_emo_threshold: |
| | | self.logit_filters.append(ThresholdEmoToken( |
| | | self.tokenizer.encode(self.options.emo_unk_token, allowed_special="all"), |
| | | self.tokenizer.encode(self.options.emo_target_tokens, allowed_special="all"), |
| | | self.options.emo_target_threshold |
| | | self.logit_filters.append( |
| | | ThresholdEmoToken( |
| | | self.tokenizer.encode(self.options.emo_unk_token, allowed_special="all"), |
| | | self.tokenizer.encode(self.options.emo_target_tokens, allowed_special="all"), |
| | | self.options.emo_target_threshold, |
| | | ) |
| | | ) |
| | | if not options.without_timestamps: |
| | | precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds |
| | | max_initial_timestamp_index = None |
| | | if options.max_initial_timestamp: |
| | | max_initial_timestamp_index = round( |
| | | self.options.max_initial_timestamp / precision |
| | | ) |
| | | max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision) |
| | | self.logit_filters.append( |
| | | ApplyTimestampRules( |
| | | tokenizer, self.sample_begin, max_initial_timestamp_index |
| | | ) |
| | | ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index) |
| | | ) |
| | | |
| | | def _verify_options(self, options: DecodingOptions) -> DecodingOptions: |
| | |
| | | raise ValueError("best_of with greedy sampling (T=0) is not compatible") |
| | | if options.patience is not None and options.beam_size is None: |
| | | raise ValueError("patience requires beam_size to be given") |
| | | if options.length_penalty is not None and not ( |
| | | 0 <= options.length_penalty <= 1 |
| | | ): |
| | | if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): |
| | | raise ValueError("length_penalty (alpha) should be a value between 0 and 1") |
| | | |
| | | return options |
| | |
| | | |
| | | if prefix := self.options.prefix: |
| | | prefix_tokens = ( |
| | | self.tokenizer.encode(" " + prefix.strip()) |
| | | if isinstance(prefix, str) |
| | | else prefix |
| | | self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix |
| | | ) |
| | | if self.sample_len is not None: |
| | | max_prefix_len = self.n_ctx // 2 - self.sample_len |
| | |
| | | |
| | | if prompt := self.options.prompt: |
| | | prompt_tokens = ( |
| | | self.tokenizer.encode(" " + prompt.strip()) |
| | | if isinstance(prompt, str) |
| | | else prompt |
| | | self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt |
| | | ) |
| | | tokens = ( |
| | | [self.tokenizer.sot_prev] |
| | | + prompt_tokens[-(self.n_ctx // 2 - 1) :] |
| | | + tokens |
| | | ) |
| | | #FIX(funasr): sense vocie |
| | | tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens |
| | | # FIX(funasr): sense vocie |
| | | if initial_prompt := self.options.initial_prompt: |
| | | if self.options.language is not None: |
| | | initial_prompt = f"{initial_prompt}<|{self.options.language}|>" |
| | |
| | | else: |
| | | tokens = self.tokenizer.encode(initial_prompt, allowed_special="all") |
| | | tokens += [0] |
| | | |
| | | |
| | | return tuple(tokens) |
| | | |
| | |
| | | else: |
| | | audio_features = self.model.encoder(mel) |
| | | |
| | | if audio_features.dtype != ( |
| | | torch.float16 if self.options.fp16 else torch.float32 |
| | | ): |
| | | return TypeError( |
| | | f"audio_features has an incorrect dtype: {audio_features.dtype}" |
| | | ) |
| | | if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32): |
| | | return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}") |
| | | |
| | | return audio_features |
| | | |
| | |
| | | languages = [max(probs, key=probs.get) for probs in lang_probs] |
| | | # FIX(funasr): sense vocie |
| | | # if self.options.language is None: |
| | | # tokens[:, self.sot_index + 1] = lang_tokens # write language tokens |
| | | # tokens[:, self.sot_index + 1] = lang_tokens # write language tokens |
| | | if self.options.language is None: |
| | | # tokens[:, self.sot_index + 1] = lang_tokens # write language tokens |
| | | languages = "".join([f"<|{language}|>" for language in languages]) |
| | | |
| | | n_audio = audio_features.shape[0] |
| | | lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to( |
| | | audio_features.device) # [n_audio, 1] |
| | | |
| | | lang_tokens = torch.tensor( |
| | | [self.tokenizer.encode(languages, allowed_special="all")] * n_audio |
| | | ).to( |
| | | audio_features.device |
| | | ) # [n_audio, 1] |
| | | |
| | | tokens[:, -1:] = lang_tokens[:, :] |
| | | languages = [languages] |
| | | |
| | |
| | | for i in range(self.sample_len): |
| | | logits = self.inference.logits(tokens, audio_features) |
| | | |
| | | if ( |
| | | i == 0 and self.tokenizer.no_speech is not None |
| | | ): # save no_speech_probs |
| | | if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs |
| | | probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1) |
| | | no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() |
| | | |
| | |
| | | languages, language_probs = self._detect_language(audio_features, tokens) |
| | | if self.options.task == "lang_id": |
| | | return [ |
| | | DecodingResult( |
| | | audio_features=features, language=language, language_probs=probs |
| | | ) |
| | | for features, language, probs in zip( |
| | | audio_features, languages, language_probs |
| | | ) |
| | | DecodingResult(audio_features=features, language=language, language_probs=probs) |
| | | for features, language, probs in zip(audio_features, languages, language_probs) |
| | | ] |
| | | |
| | | # repeat text tensors by the group size, for beam search or best-of-n sampling |
| | |
| | | # get the final candidates for each group, and slice between the first sampled token and EOT |
| | | tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) |
| | | tokens: List[List[Tensor]] = [ |
| | | [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] |
| | | for s in tokens |
| | | [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens |
| | | ] |
| | | |
| | | # select the top-ranked sample in each group |
| | |
| | | texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] |
| | | |
| | | sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] |
| | | avg_logprobs: List[float] = [ |
| | | lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs) |
| | | ] |
| | | avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] |
| | | |
| | | fields = ( |
| | | texts, |
| | |
| | | temperature=self.options.temperature, |
| | | compression_ratio=compression_ratio(text), |
| | | ) |
| | | for text, language, tokens, features, avg_logprob, no_speech_prob in zip( |
| | | *fields |
| | | ) |
| | | for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields) |
| | | ] |
| | | |
| | | |