From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/qwen_audio/model.py |  111 +++++++++++++++++++++++++++++++------------------------
 1 files changed, 62 insertions(+), 49 deletions(-)

diff --git a/funasr/models/qwen_audio/model.py b/funasr/models/qwen_audio/model.py
index e419b1e..b0af456 100644
--- a/funasr/models/qwen_audio/model.py
+++ b/funasr/models/qwen_audio/model.py
@@ -9,10 +9,10 @@
 from torch import nn
 import whisper
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.generation import GenerationConfig
+
 
 from funasr.register import tables
+
 
 @tables.register("model_classes", "Qwen/Qwen-Audio")
 @tables.register("model_classes", "Qwen-Audio")
@@ -25,50 +25,59 @@
     https://arxiv.org/abs/2311.07919
     Modified from https://github.com/QwenLM/Qwen-Audio
     """
+
     def __init__(self, *args, **kwargs):
         super().__init__()
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from transformers.generation import GenerationConfig
 
         model_or_path = kwargs.get("model_path", "QwenAudio")
-        model = AutoModelForCausalLM.from_pretrained(model_or_path, device_map="cpu",
-                                                     trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_or_path, device_map="cpu", trust_remote_code=True
+        )
         tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)
 
-        
         self.model = model
         self.tokenizer = tokenizer
-        
-    def forward(self, ):
+
+    def forward(
+        self,
+    ):
         pass
 
-    def inference(self,
-                  data_in,
-                  data_lengths=None,
-                  key: list = None,
-                  tokenizer=None,
-                  frontend=None,
-                  **kwargs,
-                  ):
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
-    
 
         meta_data = {}
         # meta_data["batch_data_time"] = -1
-
-        sp_prompt = "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
-        query = f"<audio>{data_in[0]}</audio>{sp_prompt}"
+        prompt = kwargs.get(
+            "prompt", "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
+        )
+        query = f"<audio>{data_in[0]}</audio>{prompt}"
         audio_info = self.tokenizer.process_audio(query)
-        inputs = self.tokenizer(query, return_tensors='pt', audio_info=audio_info)
+        inputs = self.tokenizer(query, return_tensors="pt", audio_info=audio_info)
         inputs = inputs.to(self.model.device)
         pred = self.model.generate(**inputs, audio_info=audio_info)
-        response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info)
+        response = self.tokenizer.decode(
+            pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info
+        )
 
         results = []
         result_i = {"key": key[0], "text": response}
-    
+
         results.append(result_i)
-    
+
         return results, meta_data
+
 
 @tables.register("model_classes", "Qwen/Qwen-Audio-Chat")
 @tables.register("model_classes", "Qwen/QwenAudioChat")
@@ -83,35 +92,37 @@
         Modified from https://github.com/QwenLM/Qwen-Audio
         """
         super().__init__()
-        
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from transformers.generation import GenerationConfig
+
         model_or_path = kwargs.get("model_path", "QwenAudio")
         bf16 = kwargs.get("bf16", False)
         fp16 = kwargs.get("fp16", False)
-        model = AutoModelForCausalLM.from_pretrained(model_or_path,
-                                                     device_map="cpu",
-                                                     bf16=bf16,
-                                                     fp16=fp16,
-                                                     trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_or_path, device_map="cpu", bf16=bf16, fp16=fp16, trust_remote_code=True
+        )
         tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)
-        
+
         self.model = model
         self.tokenizer = tokenizer
-    
-    def forward(self, ):
+
+    def forward(
+        self,
+    ):
         pass
-    
-    def inference(self,
-                  data_in,
-                  data_lengths=None,
-                  key: list = None,
-                  tokenizer=None,
-                  frontend=None,
-                  **kwargs,
-                  ):
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
-        
-        
+
         meta_data = {}
 
         prompt = kwargs.get("prompt", "what does the person say?")
@@ -119,10 +130,12 @@
         history = cache.get("history", None)
         if data_in[0] is not None:
             # 1st dialogue turn
-            query = self.tokenizer.from_list_format([
-                {'audio': data_in[0]},  # Either a local path or an url
-                {'text': prompt},
-            ])
+            query = self.tokenizer.from_list_format(
+                [
+                    {"audio": data_in[0]},  # Either a local path or an url
+                    {"text": prompt},
+                ]
+            )
         else:
             query = prompt
         response, history = self.model.chat(self.tokenizer, query=query, history=history)
@@ -132,7 +145,7 @@
 
         results = []
         result_i = {"key": key[0], "text": response}
-        
+
         results.append(result_i)
-        
+
         return results, meta_data

--
Gitblit v1.9.1