| | |
| | | from torch import nn |
| | | import whisper |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | | from transformers.generation import GenerationConfig |
| | | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | |
@tables.register("model_classes", "Qwen/Qwen-Audio")
@tables.register("model_classes", "Qwen-Audio")
@tables.register("model_classes", "Qwen/QwenAudio")
@tables.register("model_classes", "QwenAudio")
@tables.register("model_classes", "QwenAudioWarp")
class QwenAudioWarp(nn.Module):
    """Thin wrapper exposing the Qwen-Audio HuggingFace checkpoint via funasr's registry.

    Loads the causal-LM checkpoint (with its custom remote modeling code) plus
    its tokenizer, and offers a single-sample ``inference`` that transcribes
    one audio input with a fixed ASR prompt.
    """

    def __init__(self, *args, **kwargs):
        """Load model and tokenizer.

        Keyword Args:
            model_path: HF model id or local checkpoint path
                (default ``"QwenAudio"``).
        """
        super().__init__()

        model_or_path = kwargs.get("model_path", "QwenAudio")
        # trust_remote_code is required: Qwen-Audio ships custom modeling code.
        model = AutoModelForCausalLM.from_pretrained(
            model_or_path, device_map="cpu", trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)

        self.model = model
        self.tokenizer = tokenizer

    def forward(self):
        # Inference-only wrapper: training forward is intentionally a no-op.
        pass

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Transcribe one audio sample.

        Args:
            data_in: one-element list — local path or URL of the audio.
            data_lengths: unused; kept for interface compatibility.
            key: one-element list — identifier attached to the result.
            tokenizer: unused; the model's own tokenizer is used.
            frontend: unused; Qwen-Audio does its own feature extraction.

        Returns:
            Tuple ``(results, meta_data)`` where ``results`` is
            ``[{"key": ..., "text": ...}]``.

        Raises:
            NotImplementedError: if ``batch_size`` > 1 is requested.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}

        # Fixed control tokens: English transcription, no timestamps, no ITN.
        sp_prompt = "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
        query = f"<audio>{data_in[0]}</audio>{sp_prompt}"
        audio_info = self.tokenizer.process_audio(query)
        inputs = self.tokenizer(query, return_tensors='pt', audio_info=audio_info)
        inputs = inputs.to(self.model.device)
        pred = self.model.generate(**inputs, audio_info=audio_info)
        # FIX: decode with self.tokenizer — the `tokenizer` parameter defaults
        # to None and previously caused an AttributeError here.
        response = self.tokenizer.decode(
            pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info
        )

        results = []
        result_i = {"key": key[0], "text": response}
        results.append(result_i)

        return results, meta_data
| | | |
@tables.register("model_classes", "Qwen/Qwen-Audio-Chat")
@tables.register("model_classes", "Qwen/QwenAudioChat")
@tables.register("model_classes", "Qwen-Audio-Chat")
@tables.register("model_classes", "QwenAudioChat")
@tables.register("model_classes", "QwenAudioChatWarp")
class QwenAudioChatWarp(nn.Module):
    """Wrapper for the multi-turn Qwen-Audio-Chat HuggingFace checkpoint.

    First turn attaches the audio plus a text prompt; later turns reuse the
    dialogue history stored in the caller-supplied ``cache`` dict.
    """

    def __init__(self, *args, **kwargs):
        """Load the chat model and tokenizer.

        Keyword Args:
            model_path: HF model id or local checkpoint path
                (default ``"QwenAudio"``).
            bf16: load weights in bfloat16 (default False).
            fp16: load weights in float16 (default False).
        """
        super().__init__()

        model_or_path = kwargs.get("model_path", "QwenAudio")
        bf16 = kwargs.get("bf16", False)
        fp16 = kwargs.get("fp16", False)
        # trust_remote_code is required: Qwen-Audio-Chat ships custom modeling code.
        model = AutoModelForCausalLM.from_pretrained(
            model_or_path,
            device_map="cpu",
            bf16=bf16,
            fp16=fp16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)

        self.model = model
        self.tokenizer = tokenizer

    def forward(self):
        # Inference-only wrapper: training forward is intentionally a no-op.
        pass

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run one chat turn against the model.

        Args:
            data_in: one-element list — local path or URL of the audio for the
                first turn, or ``[None]`` for a follow-up text-only turn.
            data_lengths: unused; kept for interface compatibility.
            key: one-element list — identifier attached to the result.
            tokenizer: unused; the model's own tokenizer is used.
            frontend: unused; Qwen-Audio-Chat does its own feature extraction.

        Keyword Args:
            prompt: user text for this turn
                (default "what does the person say?").
            cache: dict carrying ``"history"`` across turns; updated in place.

        Returns:
            Tuple ``(results, meta_data)`` where ``results`` is
            ``[{"key": ..., "text": ...}]``.

        Raises:
            NotImplementedError: if ``batch_size`` > 1 is requested.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}

        prompt = kwargs.get("prompt", "what does the person say?")
        cache = kwargs.get("cache", {})
        history = cache.get("history", None)
        if data_in[0] is not None:
            # 1st dialogue turn: bundle the audio and the text prompt.
            query = self.tokenizer.from_list_format([
                {'audio': data_in[0]},  # Either a local path or an url
                {'text': prompt},
            ])
        else:
            # Follow-up turn: text only; the audio context lives in `history`.
            query = prompt
        response, history = self.model.chat(self.tokenizer, query=query, history=history)
        # Persist the dialogue history so the next call continues the conversation.
        cache["history"] = history
        # Example response:
        # The person says: "mister quilter is the apostle of the middle classes
        # and we are glad to welcome his gospel".

        results = []
        result_i = {"key": key[0], "text": response}
        results.append(result_i)

        return results, meta_data
| | | |