From a7d7a0f3a2e7cd44a337ced34e3536b12ccb534e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Mar 2024 19:24:44 +0800
Subject: [PATCH] Dev gzf (#1467)

---
 funasr/models/paraformer/model.py                               |    8 ++--
 examples/industrial_data_pretraining/bicif_paraformer/export.sh |    6 ++-
 funasr/bin/export.py                                            |    3 +
 funasr/models/fsmn_vad_streaming/model.py                       |    2 
 runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py        |    4 +-
 funasr/auto/auto_model.py                                       |    8 ++-
 funasr/models/ct_transformer/model.py                           |    3 -
 funasr/models/paraformer_streaming/model.py                     |    6 +-
 examples/industrial_data_pretraining/bicif_paraformer/export.py |    4 +-
 README.md                                                       |    6 +-
 runtime/python/onnxruntime/funasr_onnx/punc_bin.py              |    4 +-
 runtime/python/onnxruntime/funasr_onnx/vad_bin.py               |    8 ++--
 funasr/models/ct_transformer_streaming/model.py                 |    2 
 runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py |    4 +-
 funasr/models/bicif_paraformer/model.py                         |    7 ++-
 examples/industrial_data_pretraining/ct_transformer/export.sh   |    6 ++-
 16 files changed, 44 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index 9841fe1..d159050 100644
--- a/README.md
+++ b/README.md
@@ -215,14 +215,14 @@
 
 ### Command-line usage
 ```shell
-funasr-export ++model=paraformer ++quantize=false
+funasr-export ++model=paraformer ++quantize=false ++device=cpu
 ```
 
-### python
+### Python
 ```python
 from funasr import AutoModel
 
-model = AutoModel(model="paraformer")
+model = AutoModel(model="paraformer", device="cpu")
 
 res = model.export(quantize=False)
 ```
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.py b/examples/industrial_data_pretraining/bicif_paraformer/export.py
index 138f23a..c819f7a 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/export.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -8,7 +8,7 @@
 from funasr import AutoModel
 
 model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-                  model_revision="v2.0.4")
+                  model_revision="v2.0.4", device="cpu")
 
 res = model.export(type="onnx", quantize=False)
 print(res)
@@ -17,7 +17,7 @@
 # method2, inference from local path
 from funasr import AutoModel
 
-model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", device="cpu")
 
 res = model.export(type="onnx", quantize=False)
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.sh b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
index 42b6348..b6883b7 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/export.sh
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
@@ -12,7 +12,8 @@
 ++model=${model} \
 ++model_revision=${model_revision} \
 ++type="onnx" \
-++quantize=false
+++quantize=false \
+++device="cpu"
 
 # method2, inference from local path
 model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
@@ -20,4 +21,5 @@
 python -m funasr.bin.export \
 ++model=${model} \
 ++type="onnx" \
-++quantize=false
\ No newline at end of file
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/export.sh b/examples/industrial_data_pretraining/ct_transformer/export.sh
index 7556458..a11cda5 100644
--- a/examples/industrial_data_pretraining/ct_transformer/export.sh
+++ b/examples/industrial_data_pretraining/ct_transformer/export.sh
@@ -12,7 +12,8 @@
 ++model=${model} \
 ++model_revision=${model_revision} \
 ++type="onnx" \
-++quantize=false
+++quantize=false \
+++device="cpu"
 
 
 # method2, inference from local path
@@ -21,4 +22,5 @@
 python -m funasr.bin.export \
 ++model=${model} \
 ++type="onnx" \
-++quantize=false
\ No newline at end of file
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index c4bab03..28b9e94 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -156,7 +156,7 @@
             kwargs["batch_size"] = 1
         kwargs["device"] = device
         
-        if kwargs.get("ncpu", None):
+        if kwargs.get("ncpu", 4):
             torch.set_num_threads(kwargs.get("ncpu"))
         
         # build tokenizer
@@ -476,11 +476,13 @@
                calib_num: int = 100,
                opset_version: int = 14,
                **cfg):
-        os.environ['EXPORTING_MODEL'] = 'TRUE'
+    
+        device = cfg.get("device", "cpu")
+        model = self.model.to(device=device)
         kwargs = self.kwargs
         deep_update(kwargs, cfg)
+        kwargs["device"] = device
         del kwargs["model"]
-        model = self.model
         model.eval()
 
         batch_size = 1
diff --git a/funasr/bin/export.py b/funasr/bin/export.py
index 7d47664..cb160e9 100644
--- a/funasr/bin/export.py
+++ b/funasr/bin/export.py
@@ -24,7 +24,8 @@
     if kwargs.get("debug", False):
         import pdb; pdb.set_trace()
 
-
+    if "device" not in kwargs:
+        kwargs["device"] = "cpu"
     model = AutoModel(**kwargs)
     
     res = model.export(input=kwargs.get("input", None),
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index eb7318b..4802da0 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -23,7 +23,7 @@
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+from funasr.train_utils.device_funcs import to_device
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -348,6 +348,7 @@
         max_seq_len=512,
         **kwargs,
     ):
+        self.device = kwargs.get("device")
         is_onnx = kwargs.get("type", "onnx") == "onnx"
         encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
         self.encoder = encoder_class(self.encoder, onnx=is_onnx)
@@ -370,14 +371,14 @@
     
         return self
 
-    def _export_forward(
+    def export_forward(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
     ):
         # a. To device
         batch = {"speech": speech, "speech_lengths": speech_lengths}
-        # batch = to_device(batch, device=self.device)
+        batch = to_device(batch, device=self.device)
     
         enc, enc_len = self.encoder(**batch)
         mask = self.make_pad_mask(enc_len)[:, None, :]
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 31b8c27..88ee867 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -18,7 +18,6 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
 
-
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -378,7 +377,7 @@
         
         return self
 
-    def _export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
+    def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
         """Compute loss value from buffer sequences.
 
         Args:
diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py
index a9b2efb..4752c4b 100644
--- a/funasr/models/ct_transformer_streaming/model.py
+++ b/funasr/models/ct_transformer_streaming/model.py
@@ -186,7 +186,7 @@
     
         return self
 
-    def _export_forward(self, inputs: torch.Tensor,
+    def export_forward(self, inputs: torch.Tensor,
                 text_lengths: torch.Tensor,
                 vad_indexes: torch.Tensor,
                 sub_masks: torch.Tensor,
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index c3063b0..d06db20 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -651,7 +651,7 @@
 		
 		return self
 		
-	def _export_forward(self, feats: torch.Tensor, *args, **kwargs):
+	def export_forward(self, feats: torch.Tensor, *args, **kwargs):
 		
 		scores, out_caches = self.encoder(feats, *args)
 		
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 586d72d..f5f0e4e 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -21,7 +21,7 @@
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+from funasr.train_utils.device_funcs import to_device
 
 @tables.register("model_classes", "Paraformer")
 class Paraformer(torch.nn.Module):
@@ -554,7 +554,7 @@
         max_seq_len=512,
         **kwargs,
     ):
-        
+        self.device = kwargs.get("device")
         is_onnx = kwargs.get("type", "onnx") == "onnx"
         encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
         self.encoder = encoder_class(self.encoder, onnx=is_onnx)
@@ -579,14 +579,14 @@
         
         return self
 
-    def _export_forward(
+    def export_forward(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
     ):
         # a. To device
         batch = {"speech": speech, "speech_lengths": speech_lengths}
-        # batch = to_device(batch, device=self.device)
+        batch = to_device(batch, device=self.device)
     
         enc, enc_len = self.encoder(**batch)
         mask = self.make_pad_mask(enc_len)[:, None, :]
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index cebbfc1..63dba5d 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -566,7 +566,7 @@
         max_seq_len=512,
         **kwargs,
     ):
-    
+        self.device = kwargs.get("device")
         is_onnx = kwargs.get("type", "onnx") == "onnx"
         encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
         self.encoder = encoder_class(self.encoder, onnx=is_onnx)
@@ -612,7 +612,7 @@
     
         return encoder_model, decoder_model
 
-    def _export_encoder_forward(
+    def export_encoder_forward(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
@@ -663,7 +663,7 @@
     def export_encoder_name(self):
         return "model.onnx"
     
-    def _export_decoder_forward(
+    def export_decoder_forward(
         self,
         enc: torch.Tensor,
         enc_len: torch.Tensor,
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index e047db9..82548ad 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -63,8 +63,8 @@
                       "For the users in China, you could install with the command:\n" \
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
-            model = AutoModel(model=cache_dir)
-            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+            model = AutoModel(model=model_dir)
+            model_dir = model.export(type="onnx", quantize=quantize)
             
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
index 7da5afc..6925960 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -55,8 +55,8 @@
                       "For the users in China, you could install with the command:\n" \
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
-            model = AutoModel(model=cache_dir)
-            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+            model = AutoModel(model=model_dir)
+            model_dir = model.export(type="onnx", quantize=quantize)
 
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 4e1014f..b1aca6e 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -55,8 +55,8 @@
                       "For the users in China, you could install with the command:\n" \
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
-            model = AutoModel(model=cache_dir)
-            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+            model = AutoModel(model=model_dir)
+            model_dir = model.export(type="onnx", quantize=quantize)
             
         config_file = os.path.join(model_dir, 'punc.yaml')
         config = read_yaml(config_file)
diff --git a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
index af32b1d..384f377 100644
--- a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -61,8 +61,8 @@
 				      "For the users in China, you could install with the command:\n" \
 				      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 			
-			model = AutoModel(model=cache_dir)
-			model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+			model = AutoModel(model=model_dir)
+			model_dir = model.export(type="onnx", quantize=quantize)
 		config_file = os.path.join(model_dir, 'vad.yaml')
 		cmvn_file = os.path.join(model_dir, 'vad.mvn')
 		config = read_yaml(config_file)
@@ -225,8 +225,8 @@
 				      "For the users in China, you could install with the command:\n" \
 				      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 			
-			model = AutoModel(model=cache_dir)
-			model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+			model = AutoModel(model=model_dir)
+			model_dir = model.export(type="onnx", quantize=quantize)
 			
 		config_file = os.path.join(model_dir, 'vad.yaml')
 		cmvn_file = os.path.join(model_dir, 'vad.mvn')

--
Gitblit v1.9.1