| funasr/train_utils/trainer_ds.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/utils/misc.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/train_utils/trainer_ds.py
@@ -15,6 +15,7 @@ from funasr.train_utils.recursive_op import recursive_average from funasr.train_utils.average_nbest_models import average_checkpoints from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler import funasr.utils.misc as misc_utils try: import wandb @@ -268,7 +269,8 @@ filename = os.path.join(self.output_dir, key) logging.info(f"Delete: {filename}") if os.path.exists(filename): os.remove(filename) # os.remove(filename) misc_utils.smart_remove(filename) elif self.use_fsdp: pass @@ -360,7 +362,8 @@ filename = os.path.join(self.output_dir, key) logging.info(f"Delete: {filename}") if os.path.exists(filename): os.remove(filename) # os.remove(filename) misc_utils.smart_remove(filename) if self.use_ddp or self.use_fsdp: dist.barrier() funasr/utils/misc.py
@@ -94,3 +94,26 @@ filename, extension = os.path.splitext(filename_with_extension) # 返回不包含扩展名的文件名 return filename def smart_remove(path): """Intelligently removes files, empty directories, and non-empty directories recursively.""" # Check if the provided path exists if not os.path.exists(path): print(f"{path} does not exist.") return # If the path is a file, delete it if os.path.isfile(path): os.remove(path) print(f"File {path} has been deleted.") # If the path is a directory elif os.path.isdir(path): try: # Attempt to remove an empty directory os.rmdir(path) print(f"Empty directory {path} has been deleted.") except OSError: # If the directory is not empty, remove it along with all its contents shutil.rmtree(path) print(f"Non-empty directory {path} has been recursively deleted.")