| | |
| | | default="13_15", |
| | | help="The range of noise decibel level.", |
| | | ) |
| | | parser.add_argument( |
| | | "--batch_interval", |
| | | type=int, |
| | | default=10000, |
| | | help="The batch interval for saving model.", |
| | | ) |
| | | |
| | | for class_choices in cls.class_choices_list: |
| | | # Append --<name> and --<name>_conf. |
| | |
| | | if "model.ckpt-" in model_name or ".bin" in model_name: |
| | | model_name_pth = os.path.join(model_dir, model_name.replace('.bin', |
| | | '.pb')) if ".bin" in model_name else os.path.join( |
| | | model_dir, "{}.pth".format(model_name)) |
| | | model_dir, "{}.pb".format(model_name)) |
| | | if os.path.exists(model_name_pth): |
| | | logging.info("model_file is load from pth: {}".format(model_name_pth)) |
| | | model_dict = torch.load(model_name_pth, map_location=device) |
| | |
| | | if "model.ckpt-" in model_name or ".bin" in model_name: |
| | | model_name_pth = os.path.join(model_dir, model_name.replace('.bin', |
| | | '.pb')) if ".bin" in model_name else os.path.join( |
| | | model_dir, "{}.pth".format(model_name)) |
| | | model_dir, "{}.pb".format(model_name)) |
| | | if os.path.exists(model_name_pth): |
| | | logging.info("model_file is load from pth: {}".format(model_name_pth)) |
| | | model_dict = torch.load(model_name_pth, map_location=device) |