VirtuosoQ
2024-04-26 e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -10,6 +10,8 @@
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
from funasr.models.transformer.utils.nets_utils import to_device
if TYPE_CHECKING:
    from .model import Whisper
@@ -58,18 +60,24 @@
    # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    if x is None:
        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]
    else:
        x = x.to(mel.device)
    logits = model.logits(x[:,:-1], mel)[:, -1]
    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    mask[tokenizer.no_speech] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
            for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
        }
        for i in range(n_audio)
    ]
@@ -109,6 +117,16 @@
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True  # this will suppress blank outputs
    gain_event: bool = False  # this will suppress blank outputs
    gain_tokens_bg: Optional[Union[str, List[int]]] = "<|Speech|><|BGM|><|Applause|><|Laughter|>"
    gain_tokens_ed: Optional[Union[str, List[int]]] = "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"
    gain_tokens_score: List[float] = field(default_factory=lambda: [1, 1, 25.0, 5.0]) #[25, 5]
    use_emo_threshold: bool = False  # this will suppress blank outputs
    emo_unk_token: Optional[Union[str, List[int]]] = "<|SPECIAL_TOKEN_1|>"
    emo_target_tokens: Optional[Union[str, List[int]]] = "<|HAPPY|><|SAD|><|ANGRY|>"
    emo_target_threshold: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1]) #[25, 5]
    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
@@ -445,6 +463,48 @@
    def apply(self, logits: Tensor, tokens: Tensor):
        logits[:, self.suppress_tokens] = -np.inf
class GainEventToken(LogitFilter):
    """Boost paired audio-event tags and keep their open/close tags balanced.

    Each event is described by an opening token id (``bg``), a closing token
    id (``ed``) and a multiplicative gain. On every call, for every row of the
    batch and for each event in order:

    * ``log(gain)`` is added to the opening token's logit. Gains are clamped to
      at least ``1e-9`` before the log, so a gain of 0 becomes a large finite
      penalty rather than ``-inf``.
    * The opening token is forbidden (``-inf``) if more openings than closings
      of that event already appear in ``tokens``, or if the last token in
      ``tokens`` is this event's opening or closing token.
    * The closing token is forbidden unless more openings than closings of
      that event already appear in ``tokens``.

    Args:
        bg_tokens: opening token ids, one per event.
        ed_tokens: closing token ids, aligned with ``bg_tokens``.
        gain_values: gains, aligned with ``bg_tokens``.

    Raises:
        ValueError: if the three sequences do not have the same length.
    """

    def __init__(self, bg_tokens: Sequence[int], ed_tokens: Sequence[int], gain_values: Sequence[float]):
        self.bg_tokens = list(bg_tokens)
        self.ed_tokens = list(ed_tokens)
        # Gains are stored in log space because they are added to logits.
        self.gain_value = [np.log(max(ga, 1e-9)) for ga in gain_values]
        # Raise explicitly instead of `assert`, which is stripped under `python -O`.
        if len(self.ed_tokens) != len(self.gain_value):
            raise ValueError(
                f"GainEventToken: got {len(self.ed_tokens)} closing tokens but "
                f"{len(self.gain_value)} gain values"
            )
        if len(self.bg_tokens) != len(self.gain_value):
            raise ValueError(
                f"GainEventToken: got {len(self.bg_tokens)} opening tokens but "
                f"{len(self.gain_value)} gain values"
            )

    def apply(self, logits: Tensor, tokens: Tensor):
        """Adjust ``logits`` in place.

        Args:
            logits: indexed as ``logits[row, token_id]``.
            tokens: token history, indexed as ``tokens[row, position]``.
        """
        # Rows are independent, so the whole batch is handled with tensor ops
        # instead of per-element Python comparisons. Events are still processed
        # in the original order, so overlapping ids behave exactly as before.
        last = tokens[:, -1] if tokens.shape[-1] > 0 else None
        for bg, ed, ga in zip(self.bg_tokens, self.ed_tokens, self.gain_value):
            n_open = (tokens == bg).sum(dim=-1)
            n_close = (tokens == ed).sum(dim=-1)
            logits[:, bg] += ga
            # An event is still open, or a tag for this event was just emitted.
            block_open = n_open > n_close
            if last is not None:
                block_open = block_open | (last == bg) | (last == ed)
            logits[block_open, bg] = -np.inf
            # Nothing is open for this event, so it cannot be closed.
            logits[n_open <= n_close, ed] = -np.inf
class ThresholdEmoToken(LogitFilter):
    """Replace low-confidence emotion predictions with a fallback token.

    For each row of the batch, if the highest-scoring token is one of
    ``emo_tokens`` and its softmax probability is below that token's
    threshold, then:

    * the fallback ("unknown") token's logit is raised to at least the emotion
      token's logit;
    * the emotion token is forbidden (``-inf``).

    The fallback token therefore ends up with the emotion token's former top
    score.

    Args:
        unk_tokens: token ids of the fallback token; only the first id is used.
        emo_tokens: emotion token ids to guard.
        th_values: probability thresholds, aligned with ``emo_tokens``.

    Raises:
        ValueError: if ``unk_tokens`` is empty, or if ``emo_tokens`` and
            ``th_values`` have different lengths.
    """

    def __init__(self, unk_tokens: Sequence[int], emo_tokens: Sequence[int], th_values: Sequence[float]):
        unk_list = list(unk_tokens)
        if not unk_list:
            raise ValueError("ThresholdEmoToken: unk_tokens must contain at least one token id")
        self.unk_token = unk_list[0]
        self.emo_tokens = list(emo_tokens)
        self.th_values = list(th_values)
        # Raise explicitly instead of `assert`, which is stripped under `python -O`.
        if len(self.emo_tokens) != len(self.th_values):
            raise ValueError(
                f"ThresholdEmoToken: got {len(self.emo_tokens)} emotion tokens but "
                f"{len(self.th_values)} thresholds"
            )

    def apply(self, logits: Tensor, tokens: Tensor):
        """Adjust ``logits`` in place; ``tokens`` only supplies the batch size.

        Args:
            logits: indexed as ``logits[row, token_id]``.
            tokens: token history with one entry per batch row.
        """
        # Map emotion id -> threshold. If an id is listed twice, the first
        # threshold wins, matching the order in which ids were checked before.
        threshold_of = {}
        for emo, th in zip(self.emo_tokens, self.th_values):
            threshold_of.setdefault(emo, th)

        for i in range(len(tokens)):
            # Compute one argmax per row instead of one per emotion token.
            top = int(logits[i].argmax())
            th = threshold_of.get(top)
            if th is None:
                continue
            if logits[i].softmax(dim=-1)[top] < th:
                logits[i, self.unk_token] = max(logits[i, top], logits[i, self.unk_token])
                logits[i, top] = -np.inf
class ApplyTimestampRules(LogitFilter):
    def __init__(
@@ -565,6 +625,20 @@
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if self.options.gain_event:
            self.logit_filters.append(GainEventToken(
                self.tokenizer.encode(self.options.gain_tokens_bg, allowed_special="all"),
                self.tokenizer.encode(self.options.gain_tokens_ed, allowed_special="all"),
                self.options.gain_tokens_score
                )
            )
        if self.options.use_emo_threshold:
            self.logit_filters.append(ThresholdEmoToken(
                self.tokenizer.encode(self.options.emo_unk_token, allowed_special="all"),
                self.tokenizer.encode(self.options.emo_target_tokens, allowed_special="all"),
                self.options.emo_target_threshold
                )
            )
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None