From d097d0ca45472965d4411357d52adda5657691a2 Mon Sep 17 00:00:00 2001
From: R1ckShi <shixian.shi@alibaba-inc.com>
Date: Thu, 30 May 2024 14:59:07 +0800
Subject: [PATCH] Add contextual-paraformer export dummy inputs; move attention mask to scores' device; handle single-tensor dummy_input in torchscript export
---
funasr/utils/export_utils.py | 5 ++++-
funasr/models/whisper/model.py | 9 ++++++++-
examples/industrial_data_pretraining/paraformer/export.py | 2 +-
funasr/models/contextual_paraformer/export_meta.py | 15 +++++++++++++++
funasr/models/sanm/attention.py | 2 +-
5 files changed, 29 insertions(+), 4 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py
index fd5938a..a84ecac 100644
--- a/examples/industrial_data_pretraining/paraformer/export.py
+++ b/examples/industrial_data_pretraining/paraformer/export.py
@@ -10,7 +10,7 @@
from funasr import AutoModel
model = AutoModel(
- model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
)
res = model.export(type="torchscript", quantize=False)
diff --git a/funasr/models/contextual_paraformer/export_meta.py b/funasr/models/contextual_paraformer/export_meta.py
index 7543789..5fce7ac 100644
--- a/funasr/models/contextual_paraformer/export_meta.py
+++ b/funasr/models/contextual_paraformer/export_meta.py
@@ -16,6 +16,21 @@
self.embedding = model.bias_embed
model.bias_encoder.batch_first = False
self.bias_encoder = model.bias_encoder
+
+ def export_dummy_inputs(self):
+ hotword = torch.tensor(
+ [
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ],
+ dtype=torch.int32,
+ )
+ # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
+ return (hotword)
def export_rebuild_model(model, **kwargs):
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 08f7dc7..c7e8a8e 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -780,7 +780,7 @@
return q, k, v
def forward_attention(self, value, scores, mask, ret_attn):
- scores = scores + mask
+ scores = scores + mask.to(scores.device)
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index 8e9245a..4710b9c 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -7,7 +7,10 @@
import torch.nn.functional as F
from torch import Tensor
from torch import nn
+
import whisper
+# import whisper_timestamped as whisper
+
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
@@ -108,8 +111,12 @@
# decode the audio
options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {}))
- result = whisper.decode(self.model, speech, options)
+
+ result = whisper.decode(self.model, speech, language='english')
+ # result = whisper.transcribe(self.model, speech)
+ import pdb; pdb.set_trace()
+
results = []
result_i = {"key": key[0], "text": result.text}
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index 1deeaa7..ba200a6 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -83,7 +83,10 @@
if device == 'cuda':
model = model.cuda()
- dummy_input = tuple([i.cuda() for i in dummy_input])
+ if isinstance(dummy_input, torch.Tensor):
+ dummy_input = dummy_input.cuda()
+ else:
+ dummy_input = tuple([i.cuda() for i in dummy_input])
# model_script = torch.jit.script(model)
model_script = torch.jit.trace(model, dummy_input)
--
Gitblit v1.9.1