| | |
| | | from funasr.export.models import get_model |
| | | import numpy as np |
| | | import random |
| | | |
| | | from funasr.utils.types import str2bool |
| | | # torch_version = float(".".join(torch.__version__.split(".")[:2])) |
| | | # assert torch_version > 1.9 |
| | | |
| | |
# Warm-up / calibration pass: run the wrapped module `m` once over either
# real audio features or dummy inputs for an example trace before export.
if self.audio_in is not None:
    # Calibrate on real audio. The original ran every (feat, len) pair
    # twice — once WITHOUT no_grad (merge residue) — which wasted time and
    # built autograd state; a single no_grad loop is sufficient.
    # Also renamed the loop variable: `len` shadowed the builtin.
    feats, feats_len = self.load_feats(self.audio_in)
    for feat, feat_len in zip(feats, feats_len):
        with torch.no_grad():
            m(feat, feat_len)
else:
    # No audio supplied: fall back to the model's synthetic example inputs.
    dummy_input = model.get_dummy_inputs()
    m(*dummy_input)
| | |
# Build per-utterance feature lists from a wav.scp-style listing.
# NOTE(review): the loop body below looks like unresolved merge residue —
# `path` from the two-column "name path" split (L26-style) is immediately
# overwritten by the whole stripped line (single-column format); confirm
# which wav.scp layout is intended. The Resample(...) call is truncated
# at the edge of this chunk, so the code is left byte-identical here.
| | | feats = [] |
| | | feats_len = [] |
| | | for line in wav_list: |
| | | name, path = line.strip().split() |
| | | path = line.strip() |
| | | waveform, sampling_rate = torchaudio.load(path) |
# Resample to the frontend's expected sample rate when they differ.
| | | if sampling_rate != self.frontend.fs: |
| | | waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, |
| | |
# Load the ASR model from its training config / checkpoint files (on CPU),
# stash its frontend on self (presumably reused later for feature loading —
# confirm against the rest of the class), then run the export pipeline.
| | | model, asr_train_args = ASRTask.build_model_from_file( |
| | | asr_train_config, asr_model_file, cmvn_file, 'cpu' |
| | | ) |
| | | self.frontend = model.frontend |
| | | self._export(model, tag_name) |
| | | |
| | | |
| | |
# Command-line options for the model-export entry point.
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--export-dir', type=str, required=True)
parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
# NOTE: a duplicate `--quantize` registration (action='store_true') was
# removed — argparse raises ArgumentError on conflicting option strings.
# The str2bool form is kept: it matches the `str2bool` import in this file
# and accepts an explicit true/false value.
parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')