From aba47683fd4b2984dbff7fc79b0f532fc2d9f6b7 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Mon, 04 Mar 2024 16:44:49 +0800
Subject: [PATCH] llm_asr: rename LLMASRNAR to LLMASR and decode via single-pass argmax
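
Register the model class as LLMASR (previously LLMASRNAR) and replace the
autoregressive self.llm.generate() call in inference with a single forward
pass: take the per-position argmax of the logits, batch-decode it, and keep
only the hypothesis after the decoded prompt prefix.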

---
 funasr/models/llm_asr/model.py |   30 +++++++++++-------------------
 1 file changed, 11 insertions(+), 19 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 2b6db96..4139d8c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -20,8 +20,8 @@
 from funasr.register import tables
 
 
-@tables.register("model_classes", "LLMASRNAR")
-class LLMASRNAR(nn.Module):
+@tables.register("model_classes", "LLMASR")
+class LLMASR(nn.Module):
     """ """
     
     def __init__(
@@ -294,24 +294,16 @@
         inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
         attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
         
-        model_outputs = self.llm.generate(
-            inputs_embeds=inputs_embeds,
-            max_length=kwargs.get("max_length", 200),
-            max_new_tokens=kwargs.get("max_new_tokens", 200),
-            num_beams=kwargs.get("num_beams", 4),
-            do_sample=kwargs.get("do_sample", False),
-            min_length=kwargs.get("min_length", 1),
-            top_p=kwargs.get("top_p", 1.0),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-            length_penalty=kwargs.get("length_penalty", 1.0),
-            temperature=kwargs.get("temperature", 1.0),
-            attention_mask=attention_mask,
-            bos_token_id=tokenizer.bos_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.pad_token_id
-        )
 
-        text = tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True)
+        # Single forward pass with greedy per-position argmax over the logits,
+        # instead of autoregressive decoding via llm.generate().
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+        preds = torch.argmax(model_outputs.logits, -1)
+        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+
+        # Keep only the hypothesis after the last ": ", dropping the decoded prompt prefix.
+        text = text[0].split(': ')[-1]
+        text = text.strip()
         
         ibest_writer = None
         if kwargs.get("output_dir") is not None:

--
Gitblit v1.9.1
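
Note: the decoding pattern this patch switches to can be sketched outside
FunASR roughly as follows. This is a minimal illustration that stands in a
Hugging Face GPT-2 for self.llm; the model name, prompt, and helper calls
below are illustrative assumptions, not part of the patch.

    # Minimal sketch: single-pass greedy decoding with a stand-in causal LM.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    prompt = "Transcribe the speech: hello world"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Embed tokens manually, mirroring the patch's use of inputs_embeds.
    inputs_embeds = model.get_input_embeddings()(input_ids)
    attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long)

    with torch.no_grad():
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

    # The argmax at position t is the next-token prediction given positions
    # <= t; nothing is fed back, so this is one pass, not generation.
    preds = torch.argmax(outputs.logits, dim=-1)
    print(tokenizer.batch_decode(preds, skip_special_tokens=True)[0].strip())

Compared with the removed llm.generate() call (beam search, num_beams=4 by
default), this trades search quality for the latency of a single forward pass.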