shixian.shi
2024-03-08 7498bd7388afdde8d5e6f8a4cb6aeb8be8ac60fa
funasr/train_utils/trainer.py
@@ -14,6 +14,7 @@
 from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
 from funasr.train_utils.average_nbest_models import average_checkpoints
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
 @contextmanager
 def maybe_autocast(enabled):
@@ -84,7 +85,9 @@
         self.batch_total = 0
         self.use_fp16 = use_fp16
         self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
-        self.scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+        scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+        scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
+        self.scaler = scaler
 
 
         try:
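
This hunk selects the scaler at construction time: a plain GradScaler for ordinary fp16 training, replaced by FSDP's ShardedGradScaler when gradients are sharded across ranks (the import path, torch.distributed.fsdp.sharded_grad_scaler, ties it to FSDP, so the guard here should be the FSDP flag). Both classes expose the same interface, so the rest of the trainer is unchanged. Below is a minimal sketch of how either scaler drives an fp16 step; `train_step`, `model`, `optimizer`, and `batch` are illustrative placeholders, not FunASR's actual API:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Hypothetical helper for illustration; FunASR's real loop lives elsewhere
# in trainer.py. ShardedGradScaler is drop-in compatible with GradScaler here.
def train_step(model, optimizer, batch, scaler=None):
    optimizer.zero_grad()
    if scaler is not None:
        # fp16 path: forward under autocast, then scale the loss so that
        # small gradients survive the reduction to half precision.
        with autocast():
            loss = model(**batch)["loss"]
        scaler.scale(loss).backward()
        scaler.step(optimizer)   # unscales grads; skips the step on inf/nan
        scaler.update()          # re-tunes the scale factor for the next step
    else:
        # fp32 path: no scaling needed.
        loss = model(**batch)["loss"]
        loss.backward()
        optimizer.step()
    return loss.detach()
```
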
@@ -160,7 +163,7 @@
                 self.scaler.load_state_dict(checkpoint['scaler_state'])
             print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
-            print(f"No checkpoint found at '{ckpt}', starting from scratch")
+            print(f"No checkpoint found at '{ckpt}', not resuming training state")
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
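
On resume, the scaler's state (its running loss-scale and growth counters) is restored alongside model and optimizer state; without 'scaler_state', a resumed fp16 run would restart from the default scale and waste steps recalibrating. The trailing dist.barrier() then holds every rank until all processes have finished (or skipped) loading before training continues. A hedged sketch of a matching save/load pair follows; the "scaler_state" key comes from this hunk, while the other key names and the helpers themselves are assumptions for illustration:

```python
import torch

def save_checkpoint(path, model, optimizer, scaler=None):
    # Key names other than "scaler_state" are illustrative assumptions.
    checkpoint = {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    if scaler is not None:
        # Persist the loss-scale so a resumed fp16 run continues seamlessly.
        checkpoint["scaler_state"] = scaler.state_dict()
    torch.save(checkpoint, path)

def load_checkpoint(path, model, optimizer, scaler=None):
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    if scaler is not None and "scaler_state" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state"])
```
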