| | |
| | | from funasr.export.models import get_model |
| | | import numpy as np |
| | | import random |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2bool, str2triple_str |
| | | # torch_version = float(".".join(torch.__version__.split(".")[:2])) |
| | | # assert torch_version > 1.9 |
| | | |
| | |
| | | if __name__ == '__main__': |
| | | import argparse |
| | | parser = argparse.ArgumentParser() |
| | | parser.add_argument('--model-name', type=str, required=True) |
| | | # parser.add_argument('--model-name', type=str, required=True) |
| | | parser.add_argument('--model-name', type=str, action="append", required=True, default=[]) |
| | | parser.add_argument('--export-dir', type=str, required=True) |
| | | parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]') |
| | | parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]') |
| | |
| | | calib_num=args.calib_num, |
| | | model_revision=args.model_revision, |
| | | ) |
| | | export_model.export(args.model_name) |
| | | for model_name in args.model_name: |
| | | print("export model: {}".format(model_name)) |
| | | export_model.export(model_name) |