游雁
2024-06-13 c553a8db1712c2a5deeef5bbb68bd1fdf8d61ab7
funasr/train_utils/trainer_ds.py
@@ -29,8 +29,8 @@
         with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
             yield
     else:
-        if dtype == torch.float16:
-            with autocast(enabled=True):
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            with autocast(enabled=True, dtype=dtype):
                 yield
         else:
             yield
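This hunk widens the autocast context so bfloat16 takes the same path as float16. A minimal sketch of the resulting pattern, assuming the surrounding function is a @contextmanager (its name is not visible in the hunk; maybe_autocast below is purely illustrative):

from contextlib import contextmanager

import torch
from torch.cuda.amp import autocast

@contextmanager
def maybe_autocast(dtype=None):
    # Hypothetical reconstruction of the modified branch: both
    # half-precision dtypes enter autocast with an explicit dtype;
    # anything else yields with no autocast region at all.
    if dtype in (torch.float16, torch.bfloat16):
        with autocast(enabled=True, dtype=dtype):
            yield
    else:
        yield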
@@ -60,6 +60,7 @@
         use_ddp: bool = False,
         use_fsdp: bool = False,
         use_fp16: bool = False,
+        use_bf16: bool = False,
         use_deepspeed: bool = False,
         output_dir: str = "./",
         **kwargs,
@@ -98,8 +99,11 @@
         self.batch_total = 0
         self.dtype = torch.float32
         self.use_fp16 = use_fp16
+        self.use_bf16 = use_bf16
         if self.use_fp16:
             self.dtype = torch.float16
+        if self.use_bf16:
+            self.dtype = torch.bfloat16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
         self.validate_interval = kwargs.get("validate_interval", 5000)
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
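With the new constructor flag, the compute dtype is resolved from two booleans; because the bf16 assignment runs last, use_bf16 wins if both flags are set. A hypothetical construction call (the class name Trainer and the remaining arguments are assumptions, not shown in the diff):

trainer = Trainer(
    use_fp16=False,
    use_bf16=True,   # resolves self.dtype to torch.bfloat16
    output_dir="./exp",
)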
@@ -678,7 +682,7 @@
             scaled_loss = model.backward(loss)
         else:
             loss = loss / self.accum_grad
-            if self.use_fp16:
+            if self.use_fp16 or self.use_bf16:
                 scaler.scale(loss).backward()
             else:
                 loss.backward()
@@ -706,7 +710,7 @@
                 # Execute an optimization step (update model parameters)
                 if self.use_ddp or self.use_fsdp:
                     dist.barrier()
-                if self.use_fp16:
+                if self.use_fp16 or self.use_bf16:
                     scaler.step(optim)
                     scaler.update()
                 else:
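The last two hunks route bf16 through the same GradScaler calls as fp16, both for the scaled backward pass and for the optimizer step. A sketch of that sequence under stated assumptions (optim and accum_grad mirror trainer state visible in the hunks; the loss computation is elided). For bf16 the scaling is not strictly necessary, since bfloat16 keeps float32's exponent range, and a GradScaler constructed with enabled=False turns every call below into a pass-through, so the shared code path stays cheap either way:

import torch

scaler = torch.cuda.amp.GradScaler()  # with enabled=False, each call is a no-op

loss = loss / accum_grad         # spread the loss over accumulation steps
scaler.scale(loss).backward()    # backward on the scaled loss
scaler.step(optim)               # unscale grads, then optim.step()
scaler.update()                  # adapt the loss scale for the next step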