zhifu gao
2024-04-08 d19f48e17478be273584853568ac101c994c37e5
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -119,6 +119,7 @@
    # FIX(funasr): sense vocie
    initial_prompt: str = None
    vocab_path: str = None
@dataclass(frozen=True)
@@ -527,6 +528,7 @@
            num_languages=model.num_languages,
            language=language,
            task=options.task,
            vocab_path=options.vocab_path
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)
@@ -616,10 +618,13 @@
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )
        #FIX(gzf): sense vocie
        #FIX(funasr): sense vocie
        if initial_prompt := self.options.initial_prompt:
            tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
            if self.options.language is None:
            if self.options.language is not None:
                initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
            else:
                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
                tokens += [0]
@@ -691,6 +696,7 @@
            if self.options.language is None:
                # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
                languages = "".join([f"<|{language}|>" for language in languages])
                n_audio = audio_features.shape[0]
                lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
                    audio_features.device)  # [n_audio, 1]