zhifu gao
2023-03-29 d3c6967ca827b78cdd42bd05602c0179b2c84824
funasr/tasks/sv.py
@@ -501,7 +501,7 @@
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is load from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)