zhifu gao
2024-03-04 44a6b59468c552e5e554d1e7234efb5dcab0e0b4
Dev gzf (#1421)

* fixbug

* qwenaudio
3个文件已添加
131 ■■■■■ 已修改文件
funasr/models/qwen_audio/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/qwen_audio/model.py 85 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/qwen_audio/template.yaml 46 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/qwen_audio/__init__.py
funasr/models/qwen_audio/model.py
New file
@@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
import whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
@tables.register("model_classes", "WhisperWarp")
class WhisperWarp(nn.Module):
    """Wrapper that registers an OpenAI Whisper model under the name ``WhisperWarp``.

    The wrapped model is stored in ``self.model``. Only single-utterance
    inference is provided; ``forward`` is a placeholder.
    """

    def __init__(self, whisper_dims: dict, **kwargs):
        """Create the wrapped Whisper model.

        Args:
            whisper_dims: Keyword arguments for ``whisper.model.ModelDimensions``.
                Only read when ``hub`` is not ``"openai"``.
            **kwargs: Optional settings.
                ``hub`` (default ``"funasr"``): ``"openai"`` loads a model via
                ``whisper.load_model``; any other value builds a ``Whisper``
                module from ``whisper_dims``.
                ``init_param_path`` (default ``"large-v3"``): value passed to
                ``whisper.load_model`` when ``hub == "openai"``.
        """
        super().__init__()
        hub = kwargs.get("hub", "funasr")
        if hub == "openai":
            # A config may carry an explicit ``init_param_path: null``; treat it
            # like a missing key so ``load_model`` never receives ``None``.
            init_param_path = kwargs.get("init_param_path") or "large-v3"
            model = whisper.load_model(init_param_path)
        else:
            dims = whisper.model.ModelDimensions(**whisper_dims)
            model = whisper.model.Whisper(dims=dims)
        self.model = model

    def forward(self):
        """Placeholder; does nothing and returns ``None``."""
        pass

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        """Transcribe a single utterance with the wrapped Whisper model.

        Args:
            data_in: Either a precomputed feature tensor (used as-is when it is
                a ``torch.Tensor`` and ``data_type == "fbank"``) or any input
                accepted by ``load_audio_text_image_video``.
            data_lengths: Optional lengths matching ``data_in``; only read in
                the precomputed-feature branch.
            key: List of utterance ids; ``key[0]`` labels the result, so a
                non-empty list is required.
            tokenizer: Forwarded to ``load_audio_text_image_video``.
            frontend: Feature extractor; required (for ``fs`` and
                ``extract_fbank``) unless precomputed features are given.
            **kwargs: ``device`` (required), ``batch_size`` (default 1),
                ``data_type`` (default ``"sound"``), ``fs`` (default 16000),
                ``language`` (default ``None``).

        Returns:
            ``(results, meta_data)``: ``results`` is a one-element list holding
            ``{"key": ..., "text": ...}``; ``meta_data`` holds the
            ``load_data`` / ``extract_feat`` timings and ``batch_data_time``
            when features were extracted here, and is empty otherwise.

        Raises:
            NotImplementedError: If ``batch_size`` is greater than 1.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        data_type = kwargs.get("data_type", "sound")
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and data_type == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # NOTE(review): dim 1 is taken as the length axis, as in the
                # original code — confirm against the frontend's feature layout.
                speech_lengths = [speech.shape[1]]
            if not isinstance(speech_lengths, torch.Tensor):
                # Lengths must be a tensor: they are moved to the target
                # device below, and a plain int/list has no ``.to``.
                speech_lengths = torch.as_tensor(speech_lengths)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=data_type,
                                                            tokenizer=tokenizer)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=data_type,
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            # presumably frame_shift is in milliseconds, giving seconds here — TODO confirm
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        # Batch decoding is rejected above, so only the first item is decoded.
        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # detect the spoken language
        _, probs = self.model.detect_language(speech)
        print(f"Detected language: {max(probs, key=probs.get)}")

        # decode the audio
        options = whisper.DecodingOptions(language=kwargs.get("language", None), fp16=False)
        result = whisper.decode(self.model, speech, options)

        results = [{"key": key[0], "text": result.text}]
        return results, meta_data
funasr/models/qwen_audio/template.yaml
New file
@@ -0,0 +1,46 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: WhisperWarp
model_conf:
    lsm_weight: 0.1
    length_normalized_loss: true
    hub: funasr # openai
    init_param_path: null # large-v2 or large-v3 if hub == "openai"
# whisper_dims is only used when hub == funasr;
#  if hub == openai, it is not needed: the dimensions come with the automatically downloaded model
whisper_dims:
    'n_mels': 80
    'n_vocab': 51865
    'n_audio_ctx': 1500
    'n_audio_state': 1280
    'n_audio_head': 20
    'n_audio_layer': 32
    'n_text_ctx': 448
    'n_text_state': 1280
    'n_text_head': 20
    'n_text_layer': 32
# frontend related
frontend: WhisperFrontend
frontend_conf:
    fs: 16000
    n_mels: 80
    do_pad_trim: true
tokenizer: WhisperTokenizer
tokenizer_conf:
  language: null
  task: transcribe
  is_multilingual: true
  num_languages: 99
scope_map: ['none', "model."]