| | |
| | | tokenizer_class = tables.tokenizer_classes.get(tokenizer) |
| | | tokenizer_conf = kwargs.get("tokenizer_conf", {}) |
| | | tokenizer = tokenizer_class(**tokenizer_conf) |
| | | kwargs["tokenizer"] = tokenizer |
| | | |
| | | |
| | | kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None |
| | | kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"] |
| | | vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 |
| | | else: |
| | | vocab_size = -1 |
| | | kwargs["tokenizer"] = tokenizer |
| | | |
| | | # build frontend |
| | | frontend = kwargs.get("frontend", None) |
| | | kwargs["input_size"] = None |
| | | if frontend is not None: |
| | | frontend_class = tables.frontend_classes.get(frontend) |
| | | frontend = frontend_class(**kwargs["frontend_conf"]) |
| | | kwargs["frontend"] = frontend |
| | | kwargs["input_size"] = frontend.output_size() if hasattr(frontend, "output_size") else None |
| | | |
| | | kwargs["frontend"] = frontend |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | model = model_class(**kwargs, **kwargs.get("model_conf", {}), vocab_size=vocab_size) |
| | |
| | | # f"time_escape_all: {time_escape_total_all_samples:0.3f}") |
| | | return results_ret_list |
| | | |
def export(self, input=None,
           type: str = "onnx",
           quantize: bool = False,
           fallback_num: int = 5,
           calib_num: int = 100,
           opset_version: int = 14,
           **cfg):
    """
    Export ``self.model`` through ``export_utils`` and return the exporter's result.

    The model is moved to the requested device, switched to eval mode, and
    handed to ``export_utils.export_onnx`` when ``type == "onnx"``; any other
    value of ``type`` selects ``export_utils.export_torchscripts``.

    :param input: sample input, passed to ``prepare_data_iterator`` to obtain
        the data handed to the exporter as ``data_in``.
    :param type: export backend selector; ``"onnx"`` (default) or anything
        else for the torchscript exporter.
    :param quantize: forwarded unchanged to the exporter.
    :param fallback_num: forwarded unchanged to the exporter.
    :param calib_num: forwarded unchanged to the exporter.
    :param opset_version: forwarded unchanged to the exporter.
    :param cfg: extra keyword options. ``device`` (default ``"cpu"``) picks the
        device the model is moved to, ``data_type`` is also given to
        ``prepare_data_iterator``, and a ``model`` entry is dropped; whatever
        remains is forwarded to the exporter as keyword arguments.
    :return: the value returned by the selected exporter.
    """
    device = cfg.get("device", "cpu")
    model = self.model.to(device=device)

    # Work on a copy so the caller's mapping is never mutated.
    # NOTE(review): the statements that originally defined `kwargs` are not
    # present in this file; it is rebuilt here from `cfg` only. Confirm whether
    # instance-level configuration should be merged in as well.
    kwargs = dict(cfg)
    # The model is passed explicitly below, so a "model" entry must not also
    # travel inside **kwargs (it would be a duplicate keyword argument).
    kwargs.pop("model", None)
    model.eval()

    _, data_list = prepare_data_iterator(
        input,
        input_len=None,
        data_type=kwargs.get("data_type", None),
        key=None,
    )

    # `type` is a named parameter, so it can never appear inside `cfg`;
    # dispatch directly on it.
    if type == "onnx":
        exporter = export_utils.export_onnx
    else:
        exporter = export_utils.export_torchscripts

    export_dir = exporter(
        model=model,
        data_in=data_list,
        quantize=quantize,
        fallback_num=fallback_num,
        calib_num=calib_num,
        opset_version=opset_version,
        **kwargs,
    )

    return export_dir