liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/train_utils/trainer_ds.py
@@ -29,9 +29,10 @@
        with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
            yield
    else:
-        if dtype == torch.float16:
-            with autocast(enabled=True):
-                yield
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            yield
+            # with autocast(enabled=True, dtype=dtype):
+            #     yield
        else:
            yield
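After this hunk, the fallback branch no longer wraps fp16 in a bare `autocast(enabled=True)`; both fp16 and bf16 now fall through to a plain `yield`, with the autocast wrapper left commented out. A minimal sketch of the resulting helper, assuming it is a `contextmanager` named `maybe_autocast` and that the first branch is gated on CUDA availability (neither the name nor the gate is confirmed by the hunk):

```python
from contextlib import contextmanager

import torch

@contextmanager
def maybe_autocast(dtype: torch.dtype = torch.float32):
    # Assumed gate: the hunk only shows the CUDA-amp branch and its fallback.
    if torch.cuda.is_available():
        with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
            yield
    else:
        if dtype in (torch.float16, torch.bfloat16):
            # Post-change behavior: no autocast here; precision is expected
            # to be handled elsewhere (e.g. by the DeepSpeed engine).
            yield
        else:
            yield
```

Note that both fallback arms now reduce to a bare `yield`, so on this path autocast is effectively disabled for every dtype.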
@@ -60,6 +61,7 @@
        use_ddp: bool = False,
        use_fsdp: bool = False,
        use_fp16: bool = False,
+        use_bf16: bool = False,
        use_deepspeed: bool = False,
        output_dir: str = "./",
        **kwargs,
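Usage then looks like the sketch below; the class name `Trainer` and the argument values are assumptions from context, not verified against the FunASR API:

```python
# Hypothetical construction showing the new flag; bf16 is typically chosen
# over fp16 on Ampere-or-newer GPUs because it keeps fp32's exponent range.
trainer = Trainer(
    rank=0,
    local_rank=0,
    world_size=1,
    use_bf16=True,   # new in this commit; use_fp16 stays False
    output_dir="./exp",
)
```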
@@ -78,7 +80,7 @@
                      output_dir (str): The directory where model checkpoints will be saved. Default is './'.
                      resume (str, optional): The file path to a checkpoint to resume training from.
        """
-        self.rank = kwargs.get("rank", 0)
+        self.rank = rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.use_ddp = use_ddp
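The old line was a genuine bug if `rank` was already an explicit parameter (as `local_rank` and `world_size` appear to be here): a keyword argument that matches a named parameter never lands in `**kwargs`, so the lookup always returned the default. A self-contained demonstration of that Python gotcha:

```python
def f(rank: int = 0, **kwargs):
    # A keyword argument that matches a named parameter binds to it and
    # never appears in kwargs, so this get() always falls to the default.
    return kwargs.get("rank", -1)

print(f(rank=3))  # prints -1, not 3
```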
@@ -98,8 +100,11 @@
        self.batch_total = 0
        self.dtype = torch.float32
        self.use_fp16 = use_fp16
+        self.use_bf16 = use_bf16
        if self.use_fp16:
            self.dtype = torch.float16
+        if self.use_bf16:
+            self.dtype = torch.bfloat16
        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
        self.validate_interval = kwargs.get("validate_interval", 5000)
        self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
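Because the `use_bf16` check runs after the `use_fp16` one, bf16 silently wins when both flags are set. A stricter standalone variant of the same logic (the mutual-exclusion assert is an assumption, not in the commit):

```python
import torch

def resolve_dtype(use_fp16: bool, use_bf16: bool) -> torch.dtype:
    # Mirror of the trainer logic above, plus a guard against setting both.
    assert not (use_fp16 and use_bf16), "use_fp16 and use_bf16 are exclusive"
    if use_fp16:
        return torch.float16
    if use_bf16:
        return torch.bfloat16
    return torch.float32
```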
@@ -473,7 +478,7 @@
                            for k_ex in self.excludes:
                                k_tmp = k.replace("module.", "")
                                if k_tmp.startswith(k_ex):
                                    logging.info(f"key: {{k}} matching: {k_ex}, excluded")
                                    logging.info(f"key: {k} matching: {k_ex}, excluded")
                                    excludes_flag = True
                                    break
                        if excludes_flag:
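The old line hit an f-string escaping gotcha: doubled braces are literals, so `{{k}}` logged the three characters `{k}` instead of the key name. A minimal reproduction:

```python
k = "encoder.layers.0.weight"
print(f"key: {{k}} matching")  # old: "key: {k} matching" (literal braces)
print(f"key: {k} matching")    # fixed: "key: encoder.layers.0.weight matching"
```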
@@ -678,7 +683,7 @@
            scaled_loss = model.backward(loss)
        else:
            loss = loss / self.accum_grad
-            if self.use_fp16:
+            if self.use_fp16 or self.use_bf16:
                scaler.scale(loss).backward()
            else:
                loss.backward()
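With `use_bf16` now routed through the same `GradScaler` path as fp16, the backward half of a step looks like the sketch below (the function wrapper is illustrative; the names mirror the trainer attributes). Worth noting: loss scaling exists to keep fp16 gradients from underflowing, and bf16's fp32-sized exponent range makes it largely unnecessary there, though harmless.

```python
import torch
from torch.cuda.amp import GradScaler

def backward_step(loss: torch.Tensor, scaler: GradScaler,
                  accum_grad: int, use_fp16: bool, use_bf16: bool) -> None:
    loss = loss / accum_grad  # average over gradient-accumulation steps
    if use_fp16 or use_bf16:
        scaler.scale(loss).backward()  # scale up so tiny grads survive fp16
    else:
        loss.backward()
```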
@@ -706,7 +711,7 @@
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
-                if self.use_fp16:
+                if self.use_fp16 or self.use_bf16:
                    scaler.step(optim)
                    scaler.update()
                else:
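And the matching optimizer half: `scaler.step()` unscales the gradients and skips the update if any are inf/NaN, and `scaler.update()` adapts the loss scale for the next iteration. A sketch under the same assumptions as above (the trailing `zero_grad` is the conventional follow-up, not shown in this hunk):

```python
import torch
from torch.cuda.amp import GradScaler

def optimizer_step(optim: torch.optim.Optimizer, scaler: GradScaler,
                   use_fp16: bool, use_bf16: bool) -> None:
    if use_fp16 or use_bf16:
        scaler.step(optim)   # unscales, then steps unless grads overflowed
        scaler.update()      # grow/shrink the loss scale adaptively
    else:
        optim.step()
    optim.zero_grad()
```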