From e84f17adca2d8a81bc2d0229b9531e7eb0a7705c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 26 三月 2024 12:34:26 +0800
Subject: [PATCH] update

---
 funasr/models/llm_asr_nar/model.py |   33 +++++++++++++++++++++++----------
 1 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index a6096b2..30537cf 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -75,7 +75,7 @@
         if hub == "funasr":
             from funasr import AutoModel
             init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
-            model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+            model = AutoModel(model=init_param_path, model_revision="master")
             # frontend = model.kwargs.get("frontend")
             model.model.decoder = None
             
@@ -264,7 +264,7 @@
             audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                             data_type=kwargs.get("data_type", "sound"),
                                                             tokenizer=None)
-            if len(kwargs.get("data_type")) > 1:
+            if len(kwargs.get("data_type", [])) > 1:
                 audio_sample_list, text_token_int_list = audio_sample_list
                 text_token_int = text_token_int_list[0].replace(" ", "")
                 text_token_int = tokenizer.encode(text_token_int)
@@ -406,7 +406,7 @@
             from funasr import AutoModel
             init_param_path = encoder_conf.get("init_param_path",
                                                "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
-            model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+            model = AutoModel(model=init_param_path, model_revision="master")
             # frontend = model.kwargs.get("frontend")
             model.model.decoder = None
             
@@ -561,7 +561,7 @@
         audio_mask = kwargs.get("audio_mask", None)
         audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
         text_token_int = kwargs.get("text_token_int", None)
-        if audio_token_lengths is None:
+        if audio_token_lengths is None and text_token_int is not None:
             audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
         
         batch = {"speech": speech, "speech_lengths": speech_lengths}
@@ -572,7 +572,9 @@
                                                                                        mask=enc_mask,
                                                                                        target_label_length=audio_token_lengths,
                                                                                        )
-            loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
+            loss_pre = 0.0
+            if audio_token_lengths is not None:
+                loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
         
         return pre_acoustic_embeds, pre_token_length, loss_pre
     
@@ -603,10 +605,12 @@
             audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                             data_type=kwargs.get("data_type", "sound"),
                                                             tokenizer=None)
-            if len(kwargs.get("data_type")) > 1:
+            if len(kwargs.get("data_type", [])) > 1:
                 audio_sample_list, text_token_int_list = audio_sample_list
-                text_token_int = text_token_int_list[0].replace(" ", "")
+                text_token_int = text_token_int_list[0]
                 text_token_int = tokenizer.encode(text_token_int)
+                if text_token_int[0] == tokenizer.bos_token_id:
+                    text_token_int = text_token_int[1:]
             else:
                 text_token_int = None
             time2 = time.perf_counter()
@@ -621,24 +625,30 @@
         speech_lengths = speech_lengths.to(device=kwargs["device"])
         
         # Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int)
+        res = self.encode(speech, speech_lengths, text_token_int=text_token_int)
+        encoder_out = res[0]
         
         # adaptor
         encoder_out = self.adaptor(encoder_out)
         
         prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
         prompt_ids = tokenizer.encode(prompt_pre)
+        if prompt_ids[0] == tokenizer.bos_token_id:
+            prompt_ids = prompt_ids[1:]
+        # prompt_ids = prompt_ids + [tokenizer.pad_token_id]
         prompt_length = len(prompt_ids)
         prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+        pad = torch.tensor([tokenizer.pad_token_id], dtype=torch.int64).to(kwargs["device"])
         
         if hasattr(self.llm.model, "embed_tokens"):
             inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+            pad = self.llm.model.embed_tokens(pad)
         elif hasattr(self.llm.model.model, "embed_tokens"):
             inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
         else:
             inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
         
-        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
+        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1)  # [prompt, audio]
         attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
         
         # model_outputs = self.llm.generate(
@@ -662,8 +672,11 @@
         preds = torch.argmax(model_outputs.logits, -1)
         text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
         
-        text = text[0].split(': ')[-1]
+        text = text[0].split(':')[-1]
         text = text.strip()
+        if text.startswith("Please\n "):
+            text = text.replace("Please\n ", "")
+            text = text.strip()
         
         # preds = torch.argmax(model_outputs.logits, -1)
         

--
Gitblit v1.9.1