ming030890
2025-07-04 a750595594321833b48dc19798eed66876a100b4
funasr/train_utils/trainer_ds.py
@@ -30,9 +30,8 @@
             yield
     else:
         if dtype == torch.float16 or dtype == torch.bfloat16:
-            yield
-            # with autocast(enabled=True, dtype=dtype):
-            #     yield
+            with autocast(enabled=True, dtype=dtype):
+                yield
         else:
             yield
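Note on this hunk: the previous code had the autocast context commented out, so fp16/bf16 runs silently executed the forward pass in full precision. A minimal sketch of the pattern the uncommented block restores, assuming the `autocast` imported in this file is `torch.cuda.amp.autocast` (names below are illustrative, not the trainer's):

```python
import torch
from torch.cuda.amp import autocast  # matches autocast(enabled=True, dtype=dtype) above

def forward_with_autocast(model, batch, dtype=torch.bfloat16):
    # Run the forward pass under autocast: matmul/conv-type ops execute
    # in the reduced-precision dtype, numerically sensitive ops (e.g.
    # reductions, softmax) stay in float32.
    with autocast(enabled=True, dtype=dtype):
        return model(**batch)
```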
@@ -684,7 +683,7 @@
             scaled_loss = model.backward(loss)
         else:
             loss = loss / self.accum_grad
-            if self.use_fp16 or self.use_bf16:
+            if scaler:
                 scaler.scale(loss).backward()
             else:
                 loss.backward()
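The backward hunk swaps the precision flags for a check on the scaler object itself, so `scaler.scale(loss)` is only called when a GradScaler actually exists (bf16 runs typically don't construct one, since bf16 does not need loss scaling). A sketch of the resulting path, with hypothetical names:

```python
import torch

def backward_step(loss, accum_grad, scaler=None):
    # Divide by the gradient-accumulation factor, then backpropagate
    # through the scaler only when one exists (fp16); otherwise call
    # .backward() directly. `scaler` is a torch.cuda.amp.GradScaler or None.
    loss = loss / accum_grad
    if scaler:
        scaler.scale(loss).backward()
    else:
        loss.backward()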
@@ -712,7 +711,7 @@
                 # Execute an optimization step (update model parameters)
                 if self.use_ddp or self.use_fsdp:
                     dist.barrier()
-                if self.use_fp16 or self.use_bf16:
+                if scaler:
                     scaler.step(optim)
                     scaler.update()
                 else:
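The same gating fix applies to the optimizer step. When a scaler exists, stepping through it lets it unscale the gradients, skip steps whose gradients overflowed, and re-tune the loss scale; when it is None, the optimizer steps directly. A minimal sketch:

```python
import torch
from torch.cuda.amp import GradScaler

def optimizer_step(optim, scaler=None):
    # Sketch of the update path above, not the trainer's exact code.
    if scaler:
        scaler.step(optim)   # unscales grads; skips the step on inf/nan
        scaler.update()      # adjusts the loss scale for the next iteration
    else:
        optim.step()
    optim.zero_grad()
```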
@@ -736,6 +735,9 @@
         Args:
             epoch (int): The current epoch number.
         """
+        self.val_loss_avg = 0.0
+        self.val_acc_avg  = 0.0
         if self.use_ddp or self.use_fsdp or self.use_deepspeed:
             dist.barrier()
         logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
@@ -757,7 +759,7 @@
                    "data_split_i": kwargs.get("data_split_i", 0),
                    "data_split_num": kwargs.get("data_split_num", 1),
                    "log_step": batch_idx + kwargs.get("start_step", 0),
                    "batch_total": batch_idx + 1,
                    "batch_total": self.batch_total,
                    "step_in_epoch": batch_idx + 1,
                    "lr": 0.0,
                }
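The old value restarted the reported total at each data split, because `batch_idx` is the index within the current split; `self.batch_total` carries the cumulative count across splits. A hypothetical sketch of the distinction:

```python
class TrainerSketch:
    """Illustrative slice of the trainer state, not the real class."""
    def __init__(self):
        self.batch_total = 0  # cumulative across all data splits

    def on_batch(self, batch_idx):
        self.batch_total += 1                 # never reset between splits
        return {
            "batch_total": self.batch_total,  # cumulative count (new behavior)
            "step_in_epoch": batch_idx + 1,   # position within the current split
        }
```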
@@ -883,7 +885,7 @@
             if self.use_wandb and wandb is not None:
                 wandb.log(
                     description_dict,
-                    setp=batch_total,
+                    step=batch_total,
                 )

     def close(self, writer=None):
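Finally, `setp=` was a typo: `wandb.log` accepts a `step` keyword (which pins the logged metrics to a global step) but no `setp`, so the misspelled keyword should raise a TypeError and break W&B logging whenever this branch ran. A minimal usage sketch (the project name is illustrative):

```python
import wandb

run = wandb.init(project="funasr-demo")   # hypothetical project name
run.log({"train/loss": 0.42}, step=100)   # `step` fixes the x-axis position
```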