From 87bff7ae598279a797c27323128ca00e885d674e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 07 Feb 2023 22:51:39 +0800
Subject: [PATCH] export model

---
 funasr/export/models/e2e_asr_paraformer.py   |    7 ++-----
 funasr/export/README.md                      |   28 ++++++++++++++++++++++++++--
 funasr/models/encoder/sanm_encoder.py        |    2 +-
 funasr/export/models/encoder/sanm_encoder.py |    3 ++-
 funasr/export/export_model.py                |    2 +-
 5 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/funasr/export/README.md b/funasr/export/README.md
index 39a7265..cc36a60 100644
--- a/funasr/export/README.md
+++ b/funasr/export/README.md
@@ -1,7 +1,12 @@
 
 environment: ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.2.0
 
-Export onnx files from modelscope
+## install modelscope and funasr
+
+The install is the same as [funasr](../../README.md)
+
+## export onnx format model
+Export model from modelscope
 ```python
 from funasr.export.export_model import ASRModelExportParaformer
 
@@ -11,7 +16,26 @@
 ```
 
 
-Export onnx files from local path
+Export model from local path
+```python
+from funasr.export.export_model import ASRModelExportParaformer
+
+output_dir = "../export"
+export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
+export_model.export_from_local('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+```
+
+## export torchscript format model
+Export model from modelscope
+```python
+from funasr.export.export_model import ASRModelExportParaformer
+
+output_dir = "../export"
+export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
+export_model.export_from_modelscope('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+```
+
+Export model from local path
 ```python
 from funasr.export.export_model import ASRModelExportParaformer
 
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 9f5cb0e..9a599eb 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -20,7 +20,7 @@
         self.cache_dir = Path(cache_dir)
         self.export_config = dict(
             feats_dim=560,
             onnx=onnx,
         )
         logging.info("output dir: {}".format(self.cache_dir))
         self.onnx = onnx
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
index 8388f4f..84dd9d2 100644
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ b/funasr/export/models/e2e_asr_paraformer.py
@@ -63,12 +63,8 @@
 
         decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        sample_ids = decoder_out.argmax(dim=-1)
 
-        return decoder_out, sample_ids
-    
-    # def get_output_size(self):
-    #     return self.model.encoders[0].size
+        return decoder_out, pre_token_length
 
     def get_dummy_inputs(self):
         speech = torch.randn(2, 30, self.feats_dim)
diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py
index a3c9100..8a50538 100644
--- a/funasr/export/models/encoder/sanm_encoder.py
+++ b/funasr/export/models/encoder/sanm_encoder.py
@@ -22,6 +22,7 @@
         self.embed = model.embed
         self.model = model
         self.feats_dim = feats_dim
+        self._output_size = model._output_size
 
         if onnx:
             self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
@@ -62,7 +63,7 @@
                 speech: torch.Tensor,
                 speech_lengths: torch.Tensor,
                 ):
-            
+        speech = speech * self._output_size ** 0.5
         mask = self.make_pad_mask(speech_lengths)
         mask = self.prepare_mask(mask)
         if self.embed is None:
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 4c4bd7c..0751a10 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -293,7 +293,7 @@
             position embedded tensor and mask
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad *= self.output_size()**0.5
+        xs_pad = xs_pad * self.output_size()**0.5
         if self.embed is None:
             xs_pad = xs_pad
         elif (

--
Gitblit v1.9.1