gaochangfeng
2024-04-10 851e3e3ef83d0769d9bde172d8841f6b20e3e377
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -10,6 +10,8 @@
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
from funasr.models.transformer.utils.nets_utils import to_device
if TYPE_CHECKING:
    from .model import Whisper
@@ -17,7 +19,7 @@
@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, initial_prompt = None, x = None,
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return them as list of strings, along with the ids
@@ -48,24 +50,34 @@
        mel = mel.unsqueeze(0)
    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
    # FIX(funasr): sense vocie
    if mel.shape[-1] != model.dims.n_audio_state:
        mel = model.encoder(mel)
    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]
    # FIX(funasr): sense vocie
    # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    if x is None:
        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
    else:
        x = x.to(mel.device)
    logits = model.logits(x[:,:-1], mel)[:, -1]
    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    mask[tokenizer.no_speech] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
            for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
        }
        for i in range(n_audio)
    ]
@@ -112,6 +124,10 @@
    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation
    # FIX(funasr): sense vocie
    initial_prompt: str = None
    vocab_path: str = None
@dataclass(frozen=True)
@@ -520,6 +536,7 @@
            num_languages=model.num_languages,
            language=language,
            task=options.task,
            vocab_path=options.vocab_path
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)
@@ -609,6 +626,15 @@
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )
        #FIX(funasr): sense vocie
        if initial_prompt := self.options.initial_prompt:
            if self.options.language is not None:
                initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
            else:
                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
                tokens += [0]
        return tuple(tokens)
@@ -669,11 +695,22 @@
        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer
                audio_features, self.tokenizer, x=tokens
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            # FIX(funasr): sense vocie
            # if self.options.language is None:
                # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
                # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
                languages = "".join([f"<|{language}|>" for language in languages])
                n_audio = audio_features.shape[0]
                lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
                    audio_features.device)  # [n_audio, 1]
                tokens[:, -1:] = lang_tokens[:, :]
                languages = [languages]
        return languages, lang_probs