liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/qwen_audio/model.py
@@ -9,10 +9,10 @@
from torch import nn
import whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from funasr.register import tables
@tables.register("model_classes", "Qwen/Qwen-Audio")
@tables.register("model_classes", "Qwen-Audio")
@@ -25,50 +25,59 @@
    https://arxiv.org/abs/2311.07919
    Modified from https://github.com/QwenLM/Qwen-Audio
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from transformers.generation import GenerationConfig
        model_or_path = kwargs.get("model_path", "QwenAudio")
        model = AutoModelForCausalLM.from_pretrained(model_or_path, device_map="cpu",
                                                     trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_or_path, device_map="cpu", trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)
        self.model = model
        self.tokenizer = tokenizer
    def forward(self, ):
    def forward(
        self,
    ):
        pass
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        meta_data = {}
        # meta_data["batch_data_time"] = -1
        sp_prompt = "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
        query = f"<audio>{data_in[0]}</audio>{sp_prompt}"
        prompt = kwargs.get(
            "prompt", "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
        )
        query = f"<audio>{data_in[0]}</audio>{prompt}"
        audio_info = self.tokenizer.process_audio(query)
        inputs = self.tokenizer(query, return_tensors='pt', audio_info=audio_info)
        inputs = self.tokenizer(query, return_tensors="pt", audio_info=audio_info)
        inputs = inputs.to(self.model.device)
        pred = self.model.generate(**inputs, audio_info=audio_info)
        response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info)
        response = self.tokenizer.decode(
            pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info
        )
        results = []
        result_i = {"key": key[0], "text": response}
        results.append(result_i)
        return results, meta_data
@tables.register("model_classes", "Qwen/Qwen-Audio-Chat")
@tables.register("model_classes", "Qwen/QwenAudioChat")
@@ -83,35 +92,37 @@
        Modified from https://github.com/QwenLM/Qwen-Audio
        """
        super().__init__()
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from transformers.generation import GenerationConfig
        model_or_path = kwargs.get("model_path", "QwenAudio")
        bf16 = kwargs.get("bf16", False)
        fp16 = kwargs.get("fp16", False)
        model = AutoModelForCausalLM.from_pretrained(model_or_path,
                                                     device_map="cpu",
                                                     bf16=bf16,
                                                     fp16=fp16,
                                                     trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_or_path, device_map="cpu", bf16=bf16, fp16=fp16, trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)
        self.model = model
        self.tokenizer = tokenizer
    def forward(self, ):
    def forward(
        self,
    ):
        pass
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        meta_data = {}
        prompt = kwargs.get("prompt", "what does the person say?")
@@ -119,10 +130,12 @@
        history = cache.get("history", None)
        if data_in[0] is not None:
            # 1st dialogue turn
            query = self.tokenizer.from_list_format([
                {'audio': data_in[0]},  # Either a local path or an url
                {'text': prompt},
            ])
            query = self.tokenizer.from_list_format(
                [
                    {"audio": data_in[0]},  # Either a local path or an url
                    {"text": prompt},
                ]
            )
        else:
            query = prompt
        response, history = self.model.chat(self.tokenizer, query=query, history=history)
@@ -132,7 +145,7 @@
        results = []
        result_i = {"key": key[0], "text": response}
        results.append(result_i)
        return results, meta_data