| | |
| | | if ".bin" in model_name: |
| | | model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb')) |
| | | else: |
| | | model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name)) |
| | | model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name)) |
| | | if os.path.exists(model_name_pth): |
| | | logging.info("model_file is load from pth: {}".format(model_name_pth)) |
| | | model_dict = torch.load(model_name_pth, map_location=device) |
| | |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | if not inference: |
| | | retval = ("speech", "profile", "binary_labels") |
| | | retval = ("speech", ) |
| | | else: |
| | | # Recognition mode |
| | | retval = ("speech") |
| | | retval = ("speech", ) |
| | | return retval |
| | | |
| | | @classmethod |