From 6c467e6f0abfc6d20d0621fbbf67b4dbd81776cc Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 18 六月 2024 10:01:56 +0800
Subject: [PATCH] Merge pull request #1825 from modelscope/dev_libt
---
funasr/models/llm_asr_nar/model.py | 2
funasr/models/paraformer_streaming/model.py | 4
funasr/models/seaco_paraformer/export_meta.py | 7
examples/industrial_data_pretraining/paraformer/export.py | 19
funasr/models/transformer/model.py | 2
examples/industrial_data_pretraining/bicif_paraformer/export.py | 18
funasr/models/contextual_paraformer/export_meta.py | 19 +
runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 2
funasr/datasets/sense_voice_datasets/datasets.py | 2
funasr/models/sense_voice/model.py | 8
runtime/python/libtorch/demo_paraformer.py | 11 +
funasr/models/whisper_lid/model.py | 2
funasr/utils/export_utils.py | 147 ++++++++++++
funasr/models/whisper/model.py | 9
runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py | 2
runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py | 2
funasr/datasets/llm_datasets_vicuna/datasets.py | 2
funasr/models/bicif_paraformer/export_meta.py | 3
funasr/models/sanm/attention.py | 2
funasr/models/llm_asr/model.py | 2
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 9
funasr/auto/auto_model.py | 8
funasr/models/mfcca/mfcca_encoder.py | 2
funasr/datasets/llm_datasets_qwenaudio/datasets.py | 2
runtime/python/libtorch/funasr_torch/paraformer_bin.py | 239 ++++++++++++++++++++-
runtime/python/libtorch/demo_seaco_paraformer.py | 13 +
funasr/models/lcbnet/model.py | 2
funasr/models/bicif_paraformer/cif_predictor.py | 2
funasr/models/contextual_paraformer/decoder.py | 1
funasr/frontends/default.py | 1
funasr/models/paraformer/export_meta.py | 1
runtime/python/onnxruntime/funasr_onnx/vad_bin.py | 4
/dev/null | 17 -
runtime/python/libtorch/funasr_torch/utils/utils.py | 24 +
funasr/datasets/llm_datasets/datasets.py | 2
runtime/python/libtorch/demo_contextual_paraformer.py | 13 +
36 files changed, 477 insertions(+), 128 deletions(-)
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.py b/examples/industrial_data_pretraining/bicif_paraformer/export.py
index 31098d2..44849b0 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/export.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -12,17 +12,17 @@
device="cpu",
)
-res = model.export(type="onnx", quantize=False)
+res = model.export(type="torchscripts", quantize=False)
print(res)
-# method2, inference from local path
-from funasr import AutoModel
+# # method2, inference from local path
+# from funasr import AutoModel
-model = AutoModel(
- model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- device="cpu",
-)
+# model = AutoModel(
+# model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+# device="cpu",
+# )
-res = model.export(type="onnx", quantize=False)
-print(res)
+# res = model.export(type="onnx", quantize=False)
+# print(res)
diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py
index 19512c1..a91e9e4 100644
--- a/examples/industrial_data_pretraining/paraformer/export.py
+++ b/examples/industrial_data_pretraining/paraformer/export.py
@@ -10,19 +10,20 @@
from funasr import AutoModel
model = AutoModel(
- model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
)
-res = model.export(type="onnx", quantize=False)
+res = model.export(type="torchscripts", quantize=False)
+# res = model.export(type="bladedisc", input=f"{model.model_path}/example/asr_example.wav")
print(res)
-# method2, inference from local path
-from funasr import AutoModel
+# # method2, inference from local path
+# from funasr import AutoModel
-model = AutoModel(
- model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-)
+# model = AutoModel(
+# model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+# )
-res = model.export(type="onnx", quantize=False)
-print(res)
+# res = model.export(type="onnx", quantize=False)
+# print(res)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 603c0a0..91e80d8 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -602,12 +602,6 @@
)
with torch.no_grad():
-
- if type == "onnx":
- export_dir = export_utils.export_onnx(model=model, data_in=data_list, **kwargs)
- else:
- export_dir = export_utils.export_torchscripts(
- model=model, data_in=data_list, **kwargs
- )
+ export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
return export_dir
diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
index b660554..61caded 100644
--- a/funasr/datasets/llm_datasets/datasets.py
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -64,8 +64,6 @@
def __getitem__(self, index):
item = self.index_ds[index]
- # import pdb;
- # pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
diff --git a/funasr/datasets/llm_datasets_qwenaudio/datasets.py b/funasr/datasets/llm_datasets_qwenaudio/datasets.py
index b56e624..569665a 100644
--- a/funasr/datasets/llm_datasets_qwenaudio/datasets.py
+++ b/funasr/datasets/llm_datasets_qwenaudio/datasets.py
@@ -66,8 +66,6 @@
def __getitem__(self, index):
item = self.index_ds[index]
- # import pdb;
- # pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
diff --git a/funasr/datasets/llm_datasets_vicuna/datasets.py b/funasr/datasets/llm_datasets_vicuna/datasets.py
index 04fa514..cde29a9 100644
--- a/funasr/datasets/llm_datasets_vicuna/datasets.py
+++ b/funasr/datasets/llm_datasets_vicuna/datasets.py
@@ -66,8 +66,6 @@
def __getitem__(self, index):
item = self.index_ds[index]
- # import pdb;
- # pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index c0beda1..d4e14f2 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -72,8 +72,6 @@
return len(self.index_ds)
def __getitem__(self, index):
- # import pdb;
- # pdb.set_trace()
output = None
for idx in range(self.retry):
diff --git a/funasr/frontends/default.py b/funasr/frontends/default.py
index 462da86..68bd6fb 100644
--- a/funasr/frontends/default.py
+++ b/funasr/frontends/default.py
@@ -235,7 +235,6 @@
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
- # import pdb;pdb.set_trace()
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
diff --git a/funasr/models/bicif_paraformer/cif_predictor.py b/funasr/models/bicif_paraformer/cif_predictor.py
index 3739c76..ca98cdc 100644
--- a/funasr/models/bicif_paraformer/cif_predictor.py
+++ b/funasr/models/bicif_paraformer/cif_predictor.py
@@ -198,7 +198,7 @@
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
output2, _ = self.self_attn(output2, mask)
- # import pdb; pdb.set_trace()
+
alphas2 = torch.sigmoid(self.cif_output2(output2))
alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
# repeat the mask in T demension to match the upsampled length
diff --git a/funasr/models/bicif_paraformer/export_meta.py b/funasr/models/bicif_paraformer/export_meta.py
index e9d0a25..75171f4 100644
--- a/funasr/models/bicif_paraformer/export_meta.py
+++ b/funasr/models/bicif_paraformer/export_meta.py
@@ -29,7 +29,8 @@
model.export_input_names = types.MethodType(export_input_names, model)
model.export_output_names = types.MethodType(export_output_names, model)
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
- model.export_name = types.MethodType(export_name, model)
+
+ model.export_name = "model"
return model
diff --git a/funasr/models/contextual_paraformer/decoder.py b/funasr/models/contextual_paraformer/decoder.py
index 0b30c99..ba2ce9a 100644
--- a/funasr/models/contextual_paraformer/decoder.py
+++ b/funasr/models/contextual_paraformer/decoder.py
@@ -424,7 +424,6 @@
# contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
contextual_mask = self.make_pad_mask(contextual_length)
contextual_mask, _ = self.prepare_mask(contextual_mask)
- # import pdb; pdb.set_trace()
contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
cx, tgt_mask, _, _, _ = self.bias_decoder(
x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask
diff --git a/funasr/models/contextual_paraformer/export_meta.py b/funasr/models/contextual_paraformer/export_meta.py
index 602057f..9d3a63b 100644
--- a/funasr/models/contextual_paraformer/export_meta.py
+++ b/funasr/models/contextual_paraformer/export_meta.py
@@ -16,6 +16,21 @@
self.embedding = model.bias_embed
model.bias_encoder.batch_first = False
self.bias_encoder = model.bias_encoder
+
+ def export_dummy_inputs(self):
+ hotword = torch.tensor(
+ [
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+ [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ],
+ dtype=torch.int32,
+ )
+ # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
+ return (hotword)
def export_rebuild_model(model, **kwargs):
@@ -59,7 +74,9 @@
backbone_model.export_dynamic_axes = types.MethodType(
export_backbone_dynamic_axes, backbone_model
)
- backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model)
+
+ embedder_model.export_name = "model_eb"
+ backbone_model.export_name = "model"
return backbone_model, embedder_model
diff --git a/funasr/models/lcbnet/model.py b/funasr/models/lcbnet/model.py
index d3df25a..7b2038e 100644
--- a/funasr/models/lcbnet/model.py
+++ b/funasr/models/lcbnet/model.py
@@ -23,8 +23,6 @@
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
-import pdb
-
@tables.register("model_classes", "LCBNet")
class LCBNet(nn.Module):
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 519918c..c209026 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -166,8 +166,6 @@
text: (Batch, Length)
text_lengths: (Batch,)
"""
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index 8c0c3ff..192c199 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -166,8 +166,6 @@
text: (Batch, Length)
text_lengths: (Batch,)
"""
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
diff --git a/funasr/models/mfcca/mfcca_encoder.py b/funasr/models/mfcca/mfcca_encoder.py
index 19a6df9..a0bb58e 100644
--- a/funasr/models/mfcca/mfcca_encoder.py
+++ b/funasr/models/mfcca/mfcca_encoder.py
@@ -34,7 +34,6 @@
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
-import pdb
import math
@@ -363,7 +362,6 @@
t_leng = xs_pad.size(1)
d_dim = xs_pad.size(2)
xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
- # pdb.set_trace()
if channel_size < 8:
repeat_num = math.ceil(8 / channel_size)
xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
diff --git a/funasr/models/paraformer/export_meta.py b/funasr/models/paraformer/export_meta.py
index 5c1b6c0..db93855 100644
--- a/funasr/models/paraformer/export_meta.py
+++ b/funasr/models/paraformer/export_meta.py
@@ -31,6 +31,7 @@
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
model.export_name = types.MethodType(export_name, model)
+ model.export_name = 'model'
return model
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index f287614..16021ce 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -50,8 +50,6 @@
super().__init__(*args, **kwargs)
- # import pdb;
- # pdb.set_trace()
self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
self.scama_mask = None
@@ -83,8 +81,6 @@
text: (Batch, Length)
text_lengths: (Batch,)
"""
- # import pdb;
- # pdb.set_trace()
decoding_ind = kwargs.get("decoding_ind")
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 08f7dc7..c7e8a8e 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -780,7 +780,7 @@
return q, k, v
def forward_attention(self, value, scores, mask, ret_attn):
- scores = scores + mask
+ scores = scores + mask.to(scores.device)
self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
diff --git a/funasr/models/seaco_paraformer/export_meta.py b/funasr/models/seaco_paraformer/export_meta.py
index a246a29..d03ebd8 100644
--- a/funasr/models/seaco_paraformer/export_meta.py
+++ b/funasr/models/seaco_paraformer/export_meta.py
@@ -109,7 +109,9 @@
backbone_model.export_dynamic_axes = types.MethodType(
export_backbone_dynamic_axes, backbone_model
)
- backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model)
+
+ embedder_model.export_name = "model_eb"
+ backbone_model.export_name = "model"
return backbone_model, embedder_model
@@ -198,6 +200,3 @@
"us_cif_peak": {0: "batch_size", 1: "alphas_length"},
}
-
-def export_backbone_name(self):
- return "model.onnx"
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 22272ee..97f1b19 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -73,8 +73,6 @@
):
target_mask = kwargs.get("target_mask", None)
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@@ -303,8 +301,6 @@
):
target_mask = kwargs.get("target_mask", None)
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@@ -648,8 +644,6 @@
):
target_mask = kwargs.get("target_mask", None)
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@@ -1052,8 +1046,6 @@
):
target_mask = kwargs.get("target_mask", None)
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py
index 0d5ed23..adfd525 100644
--- a/funasr/models/transformer/model.py
+++ b/funasr/models/transformer/model.py
@@ -145,8 +145,6 @@
text: (Batch, Length)
text_lengths: (Batch,)
"""
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index 8e9245a..a332100 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -7,7 +7,10 @@
import torch.nn.functional as F
from torch import Tensor
from torch import nn
+
import whisper
+# import whisper_timestamped as whisper
+
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
@@ -108,8 +111,10 @@
# decode the audio
options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {}))
- result = whisper.decode(self.model, speech, options)
-
+
+ result = whisper.decode(self.model, speech, language='english')
+ # result = whisper.transcribe(self.model, speech)
+
results = []
result_i = {"key": key[0], "text": result.text}
diff --git a/funasr/models/whisper_lid/model.py b/funasr/models/whisper_lid/model.py
index 0701f61..02cd373 100644
--- a/funasr/models/whisper_lid/model.py
+++ b/funasr/models/whisper_lid/model.py
@@ -140,8 +140,6 @@
text: (Batch, Length)
text_lengths: (Batch,)
"""
- # import pdb;
- # pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index bc79539..5a98847 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -1,8 +1,14 @@
import os
import torch
+import functools
+
+try:
+ import torch_blade
+except Exception as e:
+ print(f"failed to load torch_blade: {e}")
-def export_onnx(model, data_in=None, quantize: bool = False, opset_version: int = 14, **kwargs):
+def export(model, data_in=None, quantize: bool = False, opset_version: int = 14, type='onnx', **kwargs):
model_scripts = model.export(**kwargs)
export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
os.makedirs(export_dir, exist_ok=True)
@@ -11,14 +17,32 @@
model_scripts = (model_scripts,)
for m in model_scripts:
m.eval()
- _onnx(
- m,
- data_in=data_in,
- quantize=quantize,
- opset_version=opset_version,
- export_dir=export_dir,
- **kwargs
- )
+ if type == 'onnx':
+ _onnx(
+ m,
+ data_in=data_in,
+ quantize=quantize,
+ opset_version=opset_version,
+ export_dir=export_dir,
+ **kwargs
+ )
+ elif type == 'torchscripts':
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print("Exporting torchscripts on device {}".format(device))
+ _torchscripts(
+ m,
+ path=export_dir,
+ device=device
+ )
+ elif type == "bladedisc":
+ assert (
+ torch.cuda.is_available()
+ ), "Currently bladedisc optimization for FunASR only supports GPU"
+ # bladedisc only optimizes encoder/decoder modules
+ if hasattr(m, "encoder") and hasattr(m, "decoder"):
+ _bladedisc_opt_for_encdec(m, path=export_dir, enable_fp16=True)
+ else:
+ _torchscripts(m, path=export_dir, device="cuda")
print("output dir: {}".format(export_dir))
return export_dir
@@ -37,7 +61,7 @@
verbose = kwargs.get("verbose", False)
- export_name = model.export_name() if hasattr(model, "export_name") else "model.onnx"
+ export_name = model.export_name + '.onnx'
model_path = os.path.join(export_dir, export_name)
torch.onnx.export(
model,
@@ -70,3 +94,106 @@
weight_type=QuantType.QUInt8,
nodes_to_exclude=nodes_to_exclude,
)
+
+
+def _torchscripts(model, path, device='cuda'):
+ dummy_input = model.export_dummy_inputs()
+
+ if device == 'cuda':
+ model = model.cuda()
+ if isinstance(dummy_input, torch.Tensor):
+ dummy_input = dummy_input.cuda()
+ else:
+ dummy_input = tuple([i.cuda() for i in dummy_input])
+
+ model_script = torch.jit.trace(model, dummy_input)
+ model_script.save(os.path.join(path, f'{model.export_name}.torchscripts'))
+
+
+def _bladedisc_opt(model, model_inputs, enable_fp16=True):
+ model = model.eval()
+ torch_config = torch_blade.config.Config()
+ torch_config.enable_fp16 = enable_fp16
+ with torch.no_grad(), torch_config:
+ opt_model = torch_blade.optimize(
+ model,
+ allow_tracing=True,
+ model_inputs=model_inputs,
+ )
+ return opt_model
+
+
+def _rescale_input_hook(m, x, scale):
+ if len(x) > 1:
+ return (x[0] / scale, *x[1:])
+ else:
+ return (x[0] / scale,)
+
+
+def _rescale_output_hook(m, x, y, scale):
+ if isinstance(y, tuple):
+ return (y[0] / scale, *y[1:])
+ else:
+ return y / scale
+
+
+def _rescale_encoder_model(model, input_data):
+ # Calculate absmax
+ absmax = torch.tensor(0).cuda()
+
+ def stat_input_hook(m, x, y):
+ val = x[0] if isinstance(x, tuple) else x
+ absmax.copy_(torch.max(absmax, val.detach().abs().max()))
+
+ encoders = model.encoder.model.encoders
+ hooks = [m.register_forward_hook(stat_input_hook) for m in encoders]
+ model = model.cuda()
+ model(*input_data)
+ for h in hooks:
+ h.remove()
+
+ # Rescale encoder modules
+ fp16_scale = int(2 * absmax // 65536)
+ print(f"rescale encoder modules with factor={fp16_scale}")
+ model.encoder.model.encoders0.register_forward_pre_hook(
+ functools.partial(_rescale_input_hook, scale=fp16_scale),
+ )
+ for name, m in model.encoder.model.named_modules():
+ if name.endswith("self_attn"):
+ m.register_forward_hook(
+ functools.partial(_rescale_output_hook, scale=fp16_scale)
+ )
+ if name.endswith("feed_forward.w_2"):
+ state_dict = {k: v / fp16_scale for k, v in m.state_dict().items()}
+ m.load_state_dict(state_dict)
+
+
+def _bladedisc_opt_for_encdec(model, path, enable_fp16):
+ # Get input data
+ # TODO: better to use real data
+ input_data = model.export_dummy_inputs()
+ if isinstance(input_data, torch.Tensor):
+ input_data = input_data.cuda()
+ else:
+ input_data = tuple([i.cuda() for i in input_data])
+
+ # Get input data for decoder module
+ decoder_inputs = list()
+
+ def get_input_hook(m, x):
+ decoder_inputs.extend(list(x))
+
+ hook = model.decoder.register_forward_pre_hook(get_input_hook)
+ model = model.cuda()
+ model(*input_data)
+ hook.remove()
+
+ # Prevent FP16 overflow
+ if enable_fp16:
+ _rescale_encoder_model(model, input_data)
+
+ # Export and optimize encoder/decoder modules
+ model.encoder = _bladedisc_opt(model.encoder, input_data[:2])
+ model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs))
+ model_script = torch.jit.trace(model, input_data)
+ model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscripts"))
diff --git a/runtime/python/libtorch/demo.py b/runtime/python/libtorch/demo.py
deleted file mode 100644
index 1ef9a20..0000000
--- a/runtime/python/libtorch/demo.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from funasr_torch import Paraformer
-
-
-model_dir = (
- "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-)
-
-model = Paraformer(model_dir, batch_size=1) # cpu
-# model = Paraformer(model_dir, batch_size=1, device_id=0) # gpu
-
-# when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
-# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
-
-wav_path = "YourPath/xx.wav"
-
-result = model(wav_path)
-print(result)
diff --git a/runtime/python/libtorch/demo_contextual_paraformer.py b/runtime/python/libtorch/demo_contextual_paraformer.py
new file mode 100644
index 0000000..306981c
--- /dev/null
+++ b/runtime/python/libtorch/demo_contextual_paraformer.py
@@ -0,0 +1,13 @@
+import torch
+from pathlib import Path
+from funasr_torch.paraformer_bin import ContextualParaformer
+
+model_dir = "iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
+device_id = 0 if torch.cuda.is_available() else -1
+model = ContextualParaformer(model_dir, batch_size=1, device_id=device_id) # gpu
+
+wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
+hotwords = "浣犵殑鐑瘝 榄旀惌"
+
+result = model(wav_path, hotwords)
+print(result)
diff --git a/runtime/python/libtorch/demo_paraformer.py b/runtime/python/libtorch/demo_paraformer.py
new file mode 100644
index 0000000..62355e2
--- /dev/null
+++ b/runtime/python/libtorch/demo_paraformer.py
@@ -0,0 +1,11 @@
+from pathlib import Path
+from funasr_torch.paraformer_bin import Paraformer
+
+model_dir = "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = Paraformer(model_dir, batch_size=1) # cpu
+# model = Paraformer(model_dir, batch_size=1, device_id=0) # gpu
+
+wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
+
+result = model(wav_path)
+print(result)
diff --git a/runtime/python/libtorch/demo_seaco_paraformer.py b/runtime/python/libtorch/demo_seaco_paraformer.py
new file mode 100644
index 0000000..ad28bfe
--- /dev/null
+++ b/runtime/python/libtorch/demo_seaco_paraformer.py
@@ -0,0 +1,13 @@
+import torch
+from pathlib import Path
+from funasr_torch.paraformer_bin import SeacoParaformer
+
+model_dir = "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+device_id = 0 if torch.cuda.is_available() else -1
+model = SeacoParaformer(model_dir, batch_size=1, device_id=device_id) # gpu
+
+wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
+hotwords = "浣犵殑鐑瘝 榄旀惌"
+
+result = model(wav_path, hotwords)
+print(result)
diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index 68886df..5fa3cc9 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -1,23 +1,29 @@
# -*- encoding: utf-8 -*-
+import json
+import copy
+import torch
import os.path
+import librosa
+import numpy as np
from pathlib import Path
from typing import List, Union, Tuple
-import copy
-import librosa
-import numpy as np
-
-from .utils.utils import CharTokenizer, Hypothesis, TokenIDConverter, get_logger, read_yaml
-from .utils.postprocess_utils import sentence_postprocess
+from .utils.utils import pad_list
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
+from .utils.postprocess_utils import sentence_postprocess
+from .utils.utils import CharTokenizer, Hypothesis, TokenIDConverter, get_logger, read_yaml
logging = get_logger()
-import torch
-
class Paraformer:
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
def __init__(
self,
model_dir: Union[str, Path] = None,
@@ -25,20 +31,42 @@
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
quantize: bool = False,
- intra_op_num_threads: int = 1,
+ cache_dir: str = None,
+ **kwargs,
):
-
if not Path(model_dir).exists():
- raise FileNotFoundError(f"{model_dir} does not exist.")
+ try:
+ from modelscope.hub.snapshot_download import snapshot_download
+ except:
+ raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+ try:
+ model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+ except:
+ raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+ model_dir
+ )
model_file = os.path.join(model_dir, "model.torchscripts")
if quantize:
model_file = os.path.join(model_dir, "model_quant.torchscripts")
+ if not os.path.exists(model_file):
+ print(".torchscripts does not exist, begin to export torchscripts")
+ try:
+ from funasr import AutoModel
+ except:
+ raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+
+ model = AutoModel(model=model_dir)
+ model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
+
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
+ token_list = os.path.join(model_dir, "tokens.json")
+ with open(token_list, "r", encoding="utf-8") as f:
+ token_list = json.load(f)
- self.converter = TokenIDConverter(config["token_list"])
+ self.converter = TokenIDConverter(token_list)
self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
self.ort_infer = torch.jit.load(model_file)
@@ -49,6 +77,10 @@
self.pred_bias = config["model_conf"]["predictor_bias"]
else:
self.pred_bias = 0
+ if "lang" in config:
+ self.language = config["lang"]
+ else:
+ self.language = None
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -203,3 +235,186 @@
token = token[: valid_token_num - self.pred_bias]
# texts = sentence_postprocess(token)
return token
+
+
+class ContextualParaformer(Paraformer):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ plot_timestamp_to: str = "",
+ quantize: bool = False,
+ cache_dir: str = None,
+ **kwargs,
+ ):
+
+ if not Path(model_dir).exists():
+ try:
+ from modelscope.hub.snapshot_download import snapshot_download
+ except:
+ raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+ try:
+ model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+ except:
+ raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+ model_dir
+ )
+
+ if quantize:
+ model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscripts")
+ model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscripts")
+ else:
+ model_bb_file = os.path.join(model_dir, "model_bb.torchscripts")
+ model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
+
+ if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
+ print(".onnx does not exist, begin to export onnx")
+ try:
+ from funasr import AutoModel
+ except:
+ raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+
+ model = AutoModel(model=model_dir)
+ model_dir = model.export(type="torchscripts", quantize=quantize, **kwargs)
+
+ config_file = os.path.join(model_dir, "config.yaml")
+ cmvn_file = os.path.join(model_dir, "am.mvn")
+ config = read_yaml(config_file)
+ token_list = os.path.join(model_dir, "tokens.json")
+ with open(token_list, "r", encoding="utf-8") as f:
+ token_list = json.load(f)
+
+ # revert token_list into vocab dict
+ self.vocab = {}
+ for i, token in enumerate(token_list):
+ self.vocab[token] = i
+
+ self.converter = TokenIDConverter(token_list)
+ self.tokenizer = CharTokenizer()
+ self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
+
+ self.ort_infer_bb = torch.jit.load(model_bb_file)
+ self.ort_infer_eb = torch.jit.load(model_eb_file)
+ self.device_id = device_id
+
+ self.batch_size = batch_size
+ self.plot_timestamp_to = plot_timestamp_to
+ if "predictor_bias" in config["model_conf"].keys():
+ self.pred_bias = config["model_conf"]["predictor_bias"]
+ else:
+ self.pred_bias = 0
+
+ def __call__(
+ self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
+ ) -> List:
+ # make hotword list
+ hotwords, hotwords_length = self.proc_hotword(hotwords)
+ if int(self.device_id) != -1:
+ bias_embed = self.eb_infer(hotwords.cuda())
+ else:
+ bias_embed = self.eb_infer(hotwords)
+ # index from bias_embed
+ bias_embed = torch.transpose(bias_embed, 0, 1)
+ _ind = np.arange(0, len(hotwords)).tolist()
+ bias_embed = bias_embed[_ind, hotwords_length.tolist()]
+ waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
+ waveform_nums = len(waveform_list)
+ asr_res = []
+ for beg_idx in range(0, waveform_nums, self.batch_size):
+ end_idx = min(waveform_nums, beg_idx + self.batch_size)
+ feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
+ bias_embed = torch.unsqueeze(bias_embed, 0).repeat(feats.shape[0], 1, 1)
+ try:
+ with torch.no_grad():
+ if int(self.device_id) == -1:
+ outputs = self.bb_infer(feats, feats_len, bias_embed)
+ am_scores, valid_token_lens = outputs[0], outputs[1]
+ else:
+ outputs = self.bb_infer(feats.cuda(), feats_len.cuda(), bias_embed.cuda())
+ am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
+ except:
+ # logging.warning(traceback.format_exc())
+ logging.warning("input wav is silence or noise")
+ preds = [""]
+ else:
+ preds = self.decode(am_scores, valid_token_lens)
+ for pred in preds:
+ pred = sentence_postprocess(pred)
+ asr_res.append({"preds": pred})
+ return asr_res
+
+ def proc_hotword(self, hotwords):
+ hotwords = hotwords.split(" ")
+ hotwords_length = [len(i) - 1 for i in hotwords]
+ hotwords_length.append(0)
+ hotwords_length = np.array(hotwords_length)
+
+ # hotwords.append('<s>')
+ def word_map(word):
+ hotwords = []
+ for c in word:
+ if c not in self.vocab.keys():
+ hotwords.append(8403)
+ logging.warning(
+ "oov character {} found in hotword {}, replaced by <unk>".format(c, word)
+ )
+ else:
+ hotwords.append(self.vocab[c])
+ return np.array(hotwords)
+
+ hotword_int = [word_map(i) for i in hotwords]
+ hotword_int.append(np.array([1]))
+ hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
+ return torch.tensor(hotwords), hotwords_length
+
+ def bb_infer(
+ self, feats, feats_len, bias_embed
+ ):
+ outputs = self.ort_infer_bb(feats, feats_len, bias_embed)
+ return outputs
+
+ def eb_infer(self, hotwords):
+ outputs = self.ort_infer_eb(hotwords.long())
+ return outputs
+
+ def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
+ return [
+ self.decode_one(am_score, token_num)
+ for am_score, token_num in zip(am_scores, token_nums)
+ ]
+
+ def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
+ yseq = am_score.argmax(axis=-1)
+ score = am_score.max(axis=-1)
+ score = np.sum(score, axis=-1)
+
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ # asr_model.sos:1 asr_model.eos:2
+ yseq = np.array([1] + yseq.tolist() + [2])
+ hyp = Hypothesis(yseq=yseq, score=score)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x not in (0, 2), token_int))
+
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
+ token = token[: valid_token_num - self.pred_bias]
+ # texts = sentence_postprocess(token)
+ return token
+
+
+class SeacoParaformer(ContextualParaformer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # no difference with contextual_paraformer in method of calling onnx models
diff --git a/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py b/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py
index 5abbafe..a10d193 100644
--- a/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py
+++ b/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py
@@ -7,7 +7,7 @@
START_END_THRESHOLD = 5
MAX_TOKEN_DURATION = 30
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
- cif_peak = us_cif_peak.reshape(-1)
+ cif_peak = us_cif_peak.reshape(-1).cpu()
num_frames = cif_peak.shape[-1]
if char_list[-1] == "</s>":
char_list = char_list[:-1]
diff --git a/runtime/python/libtorch/funasr_torch/utils/utils.py b/runtime/python/libtorch/funasr_torch/utils/utils.py
index f85d4e9..ee43852 100644
--- a/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -1,21 +1,25 @@
# -*- encoding: utf-8 -*-
-
-import functools
+import yaml
import logging
-import pickle
+import functools
+import numpy as np
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
-import numpy as np
-import yaml
-
-
-import warnings
-
root_dir = Path(__file__).resolve().parent
-
logger_initialized = {}
+def pad_list(xs, pad_value, max_len=None):
+ n_batch = len(xs)
+ if max_len is None:
+ max_len = max(x.size(0) for x in xs)
+ # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+ # numpy format
+ pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
+ for i in range(n_batch):
+ pad[i, : xs[i].shape[0]] = xs[i]
+
+ return pad
class TokenIDConverter:
def __init__(
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index 2cd43a8..a60b6d6 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -62,7 +62,7 @@
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
@@ -285,7 +285,7 @@
model_eb_file = os.path.join(model_dir, "model_eb.onnx")
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
@@ -331,7 +331,6 @@
# ) -> List:
# make hotword list
hotwords, hotwords_length = self.proc_hotword(hotwords)
- # import pdb; pdb.set_trace()
[bias_embed] = self.eb_infer(hotwords, hotwords_length)
# index from bias_embed
bias_embed = bias_embed.transpose(1, 0, 2)
@@ -411,10 +410,10 @@
return np.array(hotwords)
hotword_int = [word_map(i) for i in hotwords]
- # import pdb; pdb.set_trace()
+
hotword_int.append(np.array([1]))
hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
- # import pdb; pdb.set_trace()
+
return hotwords, hotwords_length
def bb_infer(
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
index 9b68b2f..ddd857d 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -54,7 +54,7 @@
encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 6208c09..ba55186 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -52,7 +52,7 @@
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
diff --git a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
index c195bb3..92928a8 100644
--- a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -52,7 +52,7 @@
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
@@ -221,7 +221,7 @@
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
--
Gitblit v1.9.1