from contextlib import contextmanager

import torch
from torch.cuda.amp import GradScaler

from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
@contextmanager
def maybe_autocast(enabled):
    # Enter torch.cuda.amp.autocast only when mixed precision is requested;
    # autocast(enabled=False) is a no-op, so callers can wrap the forward
    # pass unconditionally.
    with torch.cuda.amp.autocast(enabled=enabled):
        yield
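
# Usage sketch (hedged: `model` and `batch` are illustrative placeholders,
# not names defined in this file). The helper wraps the forward pass so
# fp16 autocast applies only when enabled:
#
#     with maybe_autocast(self.use_fp16):
#         retval = model(**batch)
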
self.batch_total = 0
self.use_fp16 = use_fp16
self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
# A plain GradScaler is only needed when fp16 is on; under FSDP it must be
# replaced by ShardedGradScaler, which handles sharded parameters.
scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
self.scaler = scaler
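
# Usage sketch for self.scaler in the optimizer step (hedged: `loss`, `optim`,
# and `model` are illustrative placeholders). The scaler multiplies the loss
# to avoid fp16 gradient underflow, unscales before clipping, and skips the
# step automatically when inf/nan gradients are detected:
#
#     if self.scaler is not None:
#         self.scaler.scale(loss).backward()
#         self.scaler.unscale_(optim)  # expose true grads for clipping
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#         self.scaler.step(optim)
#         self.scaler.update()
#     else:
#         loss.backward()
#         optim.step()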

try: