游雁
2024-10-11 6d932da239b3584b5735f4efb2dbb50b84c385db
funasr/models/whisper/model.py
@@ -7,7 +7,11 @@
import torch.nn.functional as F
from torch import Tensor
from torch import nn
import whisper
# import whisper_timestamped as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
@@ -24,6 +28,7 @@
@tables.register("model_classes", "Whisper-large-v1")
@tables.register("model_classes", "Whisper-large-v2")
@tables.register("model_classes", "Whisper-large-v3")
@tables.register("model_classes", "Whisper-large-v3-turbo")
@tables.register("model_classes", "WhisperWarp")
class WhisperWarp(nn.Module):
    def __init__(self, *args, **kwargs):
@@ -38,34 +43,41 @@
            dims = kwargs.get("dims", {})
            dims = whisper.model.ModelDimensions(**dims)
            model = whisper.model.Whisper(dims=dims)
        self.model = model
        self.encoder_output_size = self.model.dims.n_audio_state
    def forward(self, ):
    def forward(
        self,
    ):
        pass
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
@@ -74,13 +86,18 @@
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
@@ -96,12 +113,13 @@
        # decode the audio
        options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {}))
        result = whisper.decode(self.model, speech, options)
        result = whisper.decode(self.model, speech, options=options)
        # result = whisper.transcribe(self.model, speech)
        results = []
        result_i = {"key": key[0], "text": result.text}
        results.append(result_i)
        return results, meta_data