From fa6f60fa762f271d096b8749f3cc9bfc61a6ed48 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 23 Feb 2024 14:01:44 +0800
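Subject: [PATCH] llm_asr: read encoder init_param_path, zero -100 input ids, use float audio mask, always report acc, run predictor with autocast disabled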

---
 funasr/models/llm_asr/model.py |   25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index a903262..06323c6 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -73,7 +73,7 @@
         hub = encoder_conf.get("hub", None)
         if hub == "funasr":
             from funasr import AutoModel
-            init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
             model = AutoModel(model=init_param_path, model_revision="v2.0.4")
             # frontend = model.kwargs.get("frontend")
             model.model.decoder = None
@@ -179,6 +179,7 @@
 
         if input_ids is not None:
             input_ids[input_ids == -1] = 0
+            input_ids[input_ids == -100] = 0
             if hasattr(self.llm.model, "embed_tokens"):
                 inputs_embeds = self.llm.model.embed_tokens(input_ids)
             elif hasattr(self.llm.model.model, "embed_tokens"):
@@ -190,7 +191,7 @@
                 batch_size, token_num, dims = inputs_embeds.shape
                 _, l, _ = encoder_out.shape
                 encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
-                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
+                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
                 inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
@@ -198,11 +199,10 @@
 
 
         stats = {}
-        if self.metric:
-            with torch.no_grad():
-                preds = torch.argmax(model_outputs.logits, -1)
-                acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
-                stats["acc"] = acc_att
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+            stats["acc"] = acc_att
 
         stats["loss"] = torch.clone(loss.detach())
 
@@ -221,11 +221,12 @@
 
         batch = {"speech": speech, "speech_lengths": speech_lengths}
         enc, enc_lens = self.audio_encoder.encode(**batch)
-        enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
-                                                                           mask=enc_mask,
-                                                                           target_label_length=audio_token_lengths,
-                                                                           )
+        with autocast(False):
+            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
+                                                                               mask=enc_mask,
+                                                                               target_label_length=audio_token_lengths,
+                                                                               )
 
         return pre_acoustic_embeds, pre_token_length
 

--
Gitblit v1.9.1