From a7d7a0f3a2e7cd44a337ced34e3536b12ccb534e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Mar 2024 19:24:44 +0800
Subject: [PATCH] Dev gzf (#1467)
---
funasr/models/paraformer/model.py | 8 ++--
examples/industrial_data_pretraining/bicif_paraformer/export.sh | 6 ++-
funasr/bin/export.py | 3 +
funasr/models/fsmn_vad_streaming/model.py | 2
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 4 +-
funasr/auto/auto_model.py | 8 ++-
funasr/models/ct_transformer/model.py | 3 -
funasr/models/paraformer_streaming/model.py | 6 +-
examples/industrial_data_pretraining/bicif_paraformer/export.py | 4 +-
README.md | 6 +-
runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 4 +-
runtime/python/onnxruntime/funasr_onnx/vad_bin.py | 8 ++--
funasr/models/ct_transformer_streaming/model.py | 2
runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py | 4 +-
funasr/models/bicif_paraformer/model.py | 7 ++-
examples/industrial_data_pretraining/ct_transformer/export.sh | 6 ++-
16 files changed, 44 insertions(+), 37 deletions(-)
diff --git a/README.md b/README.md
index 9841fe1..d159050 100644
--- a/README.md
+++ b/README.md
@@ -215,14 +215,14 @@
### Command-line usage
```shell
-funasr-export ++model=paraformer ++quantize=false
+funasr-export ++model=paraformer ++quantize=false ++device=cpu
```
-### python
+### Python
```python
from funasr import AutoModel
-model = AutoModel(model="paraformer")
+model = AutoModel(model="paraformer", device="cpu")
res = model.export(quantize=False)
```
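The README hunk above makes the export device explicit in both the CLI and Python paths. As a hedged sketch (assuming, as the runtime wrappers later in this patch do, that `export()` returns the directory holding the exported artifacts), the result can be consumed like this:

```python
import os

from funasr import AutoModel

# Export on CPU; the explicit device argument is what this patch introduces.
model = AutoModel(model="paraformer", device="cpu")
export_dir = model.export(type="onnx", quantize=False)

# The ONNX runtime wrappers below read their configuration from the returned
# directory (file names taken from paraformer_bin.py in this patch).
config_file = os.path.join(export_dir, "config.yaml")
cmvn_file = os.path.join(export_dir, "am.mvn")
print(config_file, cmvn_file)
```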
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.py b/examples/industrial_data_pretraining/bicif_paraformer/export.py
index 138f23a..c819f7a 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/export.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -8,7 +8,7 @@
from funasr import AutoModel
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- model_revision="v2.0.4")
+ model_revision="v2.0.4", device="cpu")
res = model.export(type="onnx", quantize=False)
print(res)
@@ -17,7 +17,7 @@
# method2, inference from local path
from funasr import AutoModel
-model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", device="cpu")
res = model.export(type="onnx", quantize=False)
print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.sh b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
index 42b6348..b6883b7 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/export.sh
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
@@ -12,7 +12,8 @@
++model=${model} \
++model_revision=${model_revision} \
++type="onnx" \
-++quantize=false
+++quantize=false \
+++device="cpu"
# method2, inference from local path
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
@@ -20,4 +21,5 @@
python -m funasr.bin.export \
++model=${model} \
++type="onnx" \
-++quantize=false
\ No newline at end of file
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/export.sh b/examples/industrial_data_pretraining/ct_transformer/export.sh
index 7556458..a11cda5 100644
--- a/examples/industrial_data_pretraining/ct_transformer/export.sh
+++ b/examples/industrial_data_pretraining/ct_transformer/export.sh
@@ -12,7 +12,8 @@
++model=${model} \
++model_revision=${model_revision} \
++type="onnx" \
-++quantize=false
+++quantize=false \
+++device="cpu"
# method2, inference from local path
@@ -21,4 +22,5 @@
python -m funasr.bin.export \
++model=${model} \
++type="onnx" \
-++quantize=false
\ No newline at end of file
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index c4bab03..28b9e94 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -156,7 +156,7 @@
kwargs["batch_size"] = 1
kwargs["device"] = device
- if kwargs.get("ncpu", None):
+ if kwargs.get("ncpu", 4):
torch.set_num_threads(kwargs.get("ncpu"))
# build tokenizer
@@ -476,11 +476,13 @@
calib_num: int = 100,
opset_version: int = 14,
**cfg):
- os.environ['EXPORTING_MODEL'] = 'TRUE'
+
+ device = cfg.get("device", "cpu")
+ model = self.model.to(device=device)
kwargs = self.kwargs
deep_update(kwargs, cfg)
+ kwargs["device"] = device
del kwargs["model"]
- model = self.model
model.eval()
batch_size = 1
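The hunk above changes `AutoModel.export()` to pick the export device from the call-time config (defaulting to CPU), move the model there, and record the choice in `kwargs` so the exported modules can read it; it also lets `torch.set_num_threads()` run with a default of 4 threads when `ncpu` is unset. A minimal standalone sketch of the device handling follows; `prepare_export` is an illustrative name, not a real funasr function:

```python
import torch

def prepare_export(model: torch.nn.Module, kwargs: dict, **cfg):
    """Hedged paraphrase of the setup portion of AutoModel.export() after this patch."""
    device = cfg.get("device", "cpu")   # new default: export on CPU unless told otherwise
    model = model.to(device=device)     # move the model before tracing/export
    kwargs.update(cfg)                  # the real code merges with funasr's deep_update()
    kwargs["device"] = device           # later surfaced to export_forward() via self.device
    model.eval()
    return model, kwargs

# Usage sketch (auto_model is an AutoModel instance):
# export_model, export_kwargs = prepare_export(auto_model.model, dict(auto_model.kwargs), device="cpu")
```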
diff --git a/funasr/bin/export.py b/funasr/bin/export.py
index 7d47664..cb160e9 100644
--- a/funasr/bin/export.py
+++ b/funasr/bin/export.py
@@ -24,7 +24,8 @@
if kwargs.get("debug", False):
import pdb; pdb.set_trace()
-
+ if "device" not in kwargs:
+ kwargs["device"] = "cpu"
model = AutoModel(**kwargs)
res = model.export(input=kwargs.get("input", None),
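The CLI entry point now fills in `device="cpu"` when the caller omits it, so `funasr-export` keeps working without the new flag. A small illustrative sketch of that defaulting (the argument values are made up; in `funasr/bin/export.py` they come from the command line):

```python
from funasr import AutoModel

# Illustrative CLI-style kwargs; extra keys are passed through to AutoModel.
kwargs = {"model": "paraformer", "type": "onnx", "quantize": False}
if "device" not in kwargs:
    kwargs["device"] = "cpu"   # same fallback added by the hunk above

model = AutoModel(**kwargs)
res = model.export(input=kwargs.get("input", None))
print(res)
```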
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index eb7318b..4802da0 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -23,7 +23,7 @@
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+from funasr.train_utils.device_funcs import to_device
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -348,6 +348,7 @@
max_seq_len=512,
**kwargs,
):
+ self.device = kwargs.get("device")
is_onnx = kwargs.get("type", "onnx") == "onnx"
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
self.encoder = encoder_class(self.encoder, onnx=is_onnx)
@@ -370,14 +371,14 @@
return self
- def _export_forward(
+ def export_forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
# a. To device
batch = {"speech": speech, "speech_lengths": speech_lengths}
- # batch = to_device(batch, device=self.device)
+ batch = to_device(batch, device=self.device)
enc, enc_len = self.encoder(**batch)
mask = self.make_pad_mask(enc_len)[:, None, :]
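In the BiCifParaformer export path, `_export_forward` is renamed to `export_forward`, the export device is captured on `self.device`, and the previously commented-out `to_device()` call is enabled so the input batch is moved onto the same device as the module parameters before tracing. A hedged illustration of that step in isolation (the 560-dimensional LFR fbank feature size is an assumption):

```python
import torch

from funasr.train_utils.device_funcs import to_device  # import added by this patch

device = "cpu"  # value stored in self.device by export(..., device="cpu")
batch = {
    "speech": torch.randn(1, 100, 560),                       # (batch, frames, feature_dim)
    "speech_lengths": torch.tensor([100], dtype=torch.int32),
}
batch = to_device(batch, device=device)  # same call used inside export_forward()
print({k: v.device for k, v in batch.items()})
```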
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 31b8c27..88ee867 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -18,7 +18,6 @@
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
-
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
@@ -378,7 +377,7 @@
return self
- def _export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
+ def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
"""Compute loss value from buffer sequences.
Args:
diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py
index a9b2efb..4752c4b 100644
--- a/funasr/models/ct_transformer_streaming/model.py
+++ b/funasr/models/ct_transformer_streaming/model.py
@@ -186,7 +186,7 @@
return self
- def _export_forward(self, inputs: torch.Tensor,
+ def export_forward(self, inputs: torch.Tensor,
text_lengths: torch.Tensor,
vad_indexes: torch.Tensor,
sub_masks: torch.Tensor,
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index c3063b0..d06db20 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -651,7 +651,7 @@
return self
- def _export_forward(self, feats: torch.Tensor, *args, **kwargs):
+ def export_forward(self, feats: torch.Tensor, *args, **kwargs):
scores, out_caches = self.encoder(feats, *args)
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 586d72d..f5f0e4e 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -21,7 +21,7 @@
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+from funasr.train_utils.device_funcs import to_device
@tables.register("model_classes", "Paraformer")
class Paraformer(torch.nn.Module):
@@ -554,7 +554,7 @@
max_seq_len=512,
**kwargs,
):
-
+ self.device = kwargs.get("device")
is_onnx = kwargs.get("type", "onnx") == "onnx"
encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
self.encoder = encoder_class(self.encoder, onnx=is_onnx)
@@ -579,14 +579,14 @@
return self
- def _export_forward(
+ def export_forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
):
# a. To device
batch = {"speech": speech, "speech_lengths": speech_lengths}
- # batch = to_device(batch, device=self.device)
+ batch = to_device(batch, device=self.device)
enc, enc_len = self.encoder(**batch)
mask = self.make_pad_mask(enc_len)[:, None, :]
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index cebbfc1..63dba5d 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -566,7 +566,7 @@
max_seq_len=512,
**kwargs,
):
-
+ self.device = kwargs.get("device")
is_onnx = kwargs.get("type", "onnx") == "onnx"
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
self.encoder = encoder_class(self.encoder, onnx=is_onnx)
@@ -612,7 +612,7 @@
return encoder_model, decoder_model
- def _export_encoder_forward(
+ def export_encoder_forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
@@ -663,7 +663,7 @@
def export_encoder_name(self):
return "model.onnx"
- def _export_decoder_forward(
+ def export_decoder_forward(
self,
enc: torch.Tensor,
enc_len: torch.Tensor,
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index e047db9..82548ad 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -63,8 +63,8 @@
"For the users in China, you could install with the command:\n" \
"\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
- model = AutoModel(model=cache_dir)
- model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+ model = AutoModel(model=model_dir)
+ model_dir = model.export(type="onnx", quantize=quantize)
config_file = os.path.join(model_dir, 'config.yaml')
cmvn_file = os.path.join(model_dir, 'am.mvn')
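With this change the ONNX runtime wrapper re-exports from the local `model_dir` (rather than `cache_dir`) when no ONNX file is present, and no longer forwards a `device` argument since `export()` now defaults to CPU on its own. A hedged usage sketch of the wrapper (class name and call pattern taken from funasr_onnx; the wav path is illustrative):

```python
from funasr_onnx import Paraformer

# Point at a local model directory (or a downloaded ModelScope model);
# if model.onnx is missing, the wrapper triggers the export path patched above.
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1, quantize=False)

result = model(["example.wav"])  # 16 kHz mono wav, path is illustrative
print(result)
```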
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
index 7da5afc..6925960 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -55,8 +55,8 @@
"For the users in China, you could install with the command:\n" \
"\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
- model = AutoModel(model=cache_dir)
- model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+ model = AutoModel(model=model_dir)
+ model_dir = model.export(type="onnx", quantize=quantize)
config_file = os.path.join(model_dir, 'config.yaml')
cmvn_file = os.path.join(model_dir, 'am.mvn')
diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 4e1014f..b1aca6e 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -55,8 +55,8 @@
"For the users in China, you could install with the command:\n" \
"\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
- model = AutoModel(model=cache_dir)
- model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+ model = AutoModel(model=model_dir)
+ model_dir = model.export(type="onnx", quantize=quantize)
config_file = os.path.join(model_dir, 'punc.yaml')
config = read_yaml(config_file)
diff --git a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
index af32b1d..384f377 100644
--- a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -61,8 +61,8 @@
"For the users in China, you could install with the command:\n" \
"\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
- model = AutoModel(model=cache_dir)
- model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+ model = AutoModel(model=model_dir)
+ model_dir = model.export(type="onnx", quantize=quantize)
config_file = os.path.join(model_dir, 'vad.yaml')
cmvn_file = os.path.join(model_dir, 'vad.mvn')
config = read_yaml(config_file)
@@ -225,8 +225,8 @@
"For the users in China, you could install with the command:\n" \
"\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
- model = AutoModel(model=cache_dir)
- model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+ model = AutoModel(model=model_dir)
+ model_dir = model.export(type="onnx", quantize=quantize)
config_file = os.path.join(model_dir, 'vad.yaml')
cmvn_file = os.path.join(model_dir, 'vad.mvn')
--
Gitblit v1.9.1