| | |
| | | init_model/ |
| | | *.tar.gz |
| | | test_local/ |
| | | RapidASR |
| | |
| | | import logging |
| | | import torch |
| | | |
| | | from funasr.bin.asr_inference_paraformer import Speech2Text |
| | | from funasr.export.models import get_model |
| | | import numpy as np |
| | | import random |
| | | |
| | | torch_version = float(".".join(torch.__version__.split(".")[:2])) |
| | | assert torch_version > 1.9 |
| | | |
| | | class ASRModelExportParaformer: |
| | | def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True): |
| | |
| | | |
| | | def _export( |
| | | self, |
| | | model: Speech2Text, |
| | | model, |
| | | tag_name: str = None, |
| | | verbose: bool = False, |
| | | ): |
| | |
| | | os.path.join(path, f'{model.model_name}.onnx'), |
| | | verbose=verbose, |
| | | opset_version=14, |
| | | input_names=model.get_input_names(), |
| | | output_names=model.get_output_names(), |
| | | dynamic_axes=model.get_dynamic_axes() |
| | | ) |
| | | |
| | | |
| | | class ASRModelExport: |
| | | def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True): |
| | | assert check_argument_types() |
| | | self.set_all_random_seed(0) |
| | | if cache_dir is None: |
| | | cache_dir = Path.home() / ".cache" / "export" |
| | | |
| | | self.cache_dir = Path(cache_dir) |
| | | self.export_config = dict( |
| | | feats_dim=560, |
| | | onnx=False, |
| | | ) |
| | | print("output dir: {}".format(self.cache_dir)) |
| | | self.onnx = onnx |
| | | |
| | | def _export( |
| | | self, |
| | | model: Speech2Text, |
| | | tag_name: str = None, |
| | | verbose: bool = False, |
| | | ): |
| | | |
| | | export_dir = self.cache_dir / tag_name.replace(' ', '-') |
| | | os.makedirs(export_dir, exist_ok=True) |
| | | |
| | | # export encoder1 |
| | | self.export_config["model_name"] = "model" |
| | | model = get_model( |
| | | model, |
| | | self.export_config, |
| | | ) |
| | | model.eval() |
| | | # self._export_onnx(model, verbose, export_dir) |
| | | if self.onnx: |
| | | self._export_onnx(model, verbose, export_dir) |
| | | else: |
| | | self._export_torchscripts(model, verbose, export_dir) |
| | | |
| | | print("output dir: {}".format(export_dir)) |
| | | |
| | | def _export_torchscripts(self, model, verbose, path, enc_size=None): |
| | | if enc_size: |
| | | dummy_input = model.get_dummy_inputs(enc_size) |
| | | else: |
| | | dummy_input = model.get_dummy_inputs_txt() |
| | | |
| | | # model_script = torch.jit.script(model) |
| | | model_script = torch.jit.trace(model, dummy_input) |
| | | model_script.save(os.path.join(path, f'{model.model_name}.torchscripts')) |
| | | |
| | | def set_all_random_seed(self, seed: int): |
| | | random.seed(seed) |
| | | np.random.seed(seed) |
| | | torch.random.manual_seed(seed) |
| | | |
| | | def export(self, |
| | | tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', |
| | | mode: str = 'paraformer', |
| | | ): |
| | | |
| | | model_dir = tag_name |
| | | if model_dir.startswith('damo/'): |
| | | from modelscope.hub.snapshot_download import snapshot_download |
| | | model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir) |
| | | asr_train_config = os.path.join(model_dir, 'config.yaml') |
| | | asr_model_file = os.path.join(model_dir, 'model.pb') |
| | | cmvn_file = os.path.join(model_dir, 'am.mvn') |
| | | json_file = os.path.join(model_dir, 'configuration.json') |
| | | if mode is None: |
| | | import json |
| | | with open(json_file, 'r') as f: |
| | | config_data = json.load(f) |
| | | mode = config_data['model']['model_config']['mode'] |
| | | if mode.startswith('paraformer'): |
| | | from funasr.tasks.asr import ASRTaskParaformer as ASRTask |
| | | elif mode.startswith('uniasr'): |
| | | from funasr.tasks.asr import ASRTaskUniASR as ASRTask |
| | | |
| | | model, asr_train_args = ASRTask.build_model_from_file( |
| | | asr_train_config, asr_model_file, cmvn_file, 'cpu' |
| | | ) |
| | | self._export(model, tag_name) |
| | | |
| | | def _export_onnx(self, model, verbose, path, enc_size=None): |
| | | if enc_size: |
| | | dummy_input = model.get_dummy_inputs(enc_size) |
| | | else: |
| | | dummy_input = model.get_dummy_inputs() |
| | | |
| | | # model_script = torch.jit.script(model) |
| | | model_script = model # torch.jit.trace(model) |
| | | |
| | | torch.onnx.export( |
| | | model_script, |
| | | dummy_input, |
| | | os.path.join(path, f'{model.model_name}.onnx'), |
| | | verbose=verbose, |
| | | opset_version=12, |
| | | input_names=model.get_input_names(), |
| | | output_names=model.get_output_names(), |
| | | dynamic_axes=model.get_dynamic_axes() |
| | |
| | | return hidden, alphas, token_num_floor
|
| | |
|
| | |
|
| | | # @torch.jit.script
|
| | | # def cif(hidden, alphas, threshold: float):
|
| | | # batch_size, len_time, hidden_size = hidden.size()
|
| | | # threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
|
| | | #
|
| | | # # loop varss
|
| | | # integrate = torch.zeros([batch_size], device=hidden.device)
|
| | | # frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
|
| | | # # intermediate vars along time
|
| | | # list_fires = []
|
| | | # list_frames = []
|
| | | #
|
| | | # for t in range(len_time):
|
| | | # alpha = alphas[:, t]
|
| | | # distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
|
| | | #
|
| | | # integrate += alpha
|
| | | # list_fires.append(integrate)
|
| | | #
|
| | | # fire_place = integrate >= threshold
|
| | | # integrate = torch.where(fire_place,
|
| | | # integrate - torch.ones([batch_size], device=hidden.device),
|
| | | # integrate)
|
| | | # cur = torch.where(fire_place,
|
| | | # distribution_completion,
|
| | | # alpha)
|
| | | # remainds = alpha - cur
|
| | | #
|
| | | # frame += cur[:, None] * hidden[:, t, :]
|
| | | # list_frames.append(frame)
|
| | | # frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
|
| | | # remainds[:, None] * hidden[:, t, :],
|
| | | # frame)
|
| | | #
|
| | | # fires = torch.stack(list_fires, 1)
|
| | | # frames = torch.stack(list_frames, 1)
|
| | | # list_ls = []
|
| | | # len_labels = torch.floor(alphas.sum(-1)).int()
|
| | | # max_label_len = len_labels.max()
|
| | | # for b in range(batch_size):
|
| | | # fire = fires[b, :]
|
| | | # l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
|
| | | # pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
|
| | | # list_ls.append(torch.cat([l, pad_l], 0))
|
| | | # return torch.stack(list_ls, 0), fires
|
| | |
|
| | |
|
| | | @torch.jit.script
|
| | | def cif(hidden, alphas, threshold: float):
|
| | | batch_size, len_time, hidden_size = hidden.size()
|
| | |
| | |
|
| | | fires = torch.stack(list_fires, 1)
|
| | | frames = torch.stack(list_frames, 1)
|
| | | # list_ls = []
|
| | | len_labels = torch.round(alphas.sum(-1)).type(torch.int32)
|
| | | # max_label_len = int(torch.max(len_labels).item())
|
| | | # print("type: {}".format(type(max_label_len)))
|
| | |
|
| | | fire_idxs = fires >= threshold
|
| | | frame_fires = torch.zeros_like(hidden)
|
| | | max_label_len = frames[0, fire_idxs[0]].size(0)
|
| | | for b in range(batch_size):
|
| | | # fire = fires[b, :]
|
| | | frame_fire = frames[b, fire_idxs[b]]
|
| | | frame_len = frame_fire.size(0)
|
| | | frame_fires[b, :frame_len, :] = frame_fire
|
| | |
| | | |
| | | class OrtInferSession(): |
| | | def __init__(self, model_file, device_id=-1): |
| | | device_id = str(device_id) |
| | | sess_opt = SessionOptions() |
| | | sess_opt.log_severity_level = 4 |
| | | sess_opt.enable_cpu_mem_arena = False |
| | |
| | | } |
| | | |
| | | EP_list = [] |
| | | if device_id != -1 and get_device() == 'GPU' \ |
| | | if device_id != "-1" and get_device() == 'GPU' \ |
| | | and cuda_ep in get_available_providers(): |
| | | EP_list = [(cuda_ep, cuda_provider_options)] |
| | | EP_list.append((cpu_ep, cpu_provider_options)) |
| | |
| | | sess_options=sess_opt, |
| | | providers=EP_list) |
| | | |
| | | if device_id != -1 and cuda_ep not in self.session.get_providers(): |
| | | if device_id != "-1" and cuda_ep not in self.session.get_providers(): |
| | | warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n' |
| | | 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, ' |
| | | 'you can check their relations from the offical web site: ' |