游雁
2023-12-06 15868f623089cf70983a8b4f435ff86e7f160b8a
funasr2
2个文件已修改
90 ■■■■ 已修改文件
funasr/cli/train_cli.py 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/cli/trainer.py 78 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/cli/train_cli.py
@@ -46,7 +46,7 @@
    
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    # Check if we are using DDP or FSDP
    use_ddp = 'WORLD_SIZE' in os.environ
    use_ddp = 'WORLD_SIZE' in os.environ and os.environ["WORLD_SIZE"] > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
@@ -109,7 +109,8 @@
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(model, device_ids=[local_rank])
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
    else:
@@ -156,13 +157,6 @@
    if use_ddp or use_fsdp:
        torch.distributed.destroy_process_group()
def train(epoch, model, op):
    """No-op training stub.

    The body performs no work; every argument is ignored.

    Args:
        epoch: Ignored.
        model: Ignored.
        op: Ignored (name suggests an optimizer — confirm with callers).

    Returns:
        None, always.
    """
    return None
def val():
    """No-op validation stub.

    Takes no arguments and performs no work.

    Returns:
        None, always.
    """
    return None
if __name__ == "__main__":
funasr/cli/trainer.py
@@ -5,6 +5,7 @@
from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist
from funasr.torch_utils.recursive_op import recursive_average
class Trainer:
    """
@@ -56,6 +57,8 @@
        self.start_epoch = 1
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = torch.device("cuda", local_rank)
@@ -113,7 +116,7 @@
            # self._validate_epoch(epoch)
            if dist.get_rank() == 0:
                self._save_checkpoint(epoch)
            # self.scheduler.step()
            self.scheduler.step()
    
    def _train_epoch(self, epoch):
        """
@@ -126,24 +129,34 @@
                    dynamic_ncols=True)
        
        # Set the number of steps for gradient accumulation
        accumulation_steps = self.kwargs.get("accumulation_steps", 1)
        accum_grad = self.kwargs.get("accum_grad", 1)
        # Initialize the gradient accumulation
        self.optim.zero_grad()
        
        for batch_idx, batch in enumerate(self.dataloader_train):
            batch = to_device(batch, self.device)
            
            my_context = self.model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                retval = self.model(**batch)
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is summation over all workers
                    loss /= weight
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accumulation_steps
                loss = loss / accum_grad
                loss.backward()
            
            # Perform an optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(self.dataloader_train):
            if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
                # Perform gradient clipping if it is set
                if self.kwargs.get("grad_clip", None) is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -171,43 +184,6 @@
            
        pbar.close()
    
    # def _train_epoch(self, epoch):
    #     """
    #     Defines the training process for a single epoch.
    #     Should be implemented with the actual model training steps.
    #
    #     Args:
    #         epoch (int): The current epoch number.
    #     """
    #     self.model.train()
    #     pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train), dynamic_ncols=True)
    #     for batch_idx, batch in enumerate(self.dataloader_train):
    #         batch = to_device(batch, "cpu")
    #         retval = self.model(**batch)
    #         loss, stats, weight = retval
    #         self.optim.zero_grad()
    #         loss.backward()
    #
    #         # compute the gradient norm to check if it is normal or not
    #         grad_norm = torch.nn.utils.clip_grad_norm_(
    #             self.model.parameters(),
    #             max_norm=self.kwargs.get("grad_clip", 10.0),
    #             norm_type=self.kwargs.get("grad_clip_type", 2.0),
    #         )
    #         if not torch.isfinite(grad_norm):
    #             logging.warning(
    #                 f"The grad norm is {grad_norm}. Skipping updating the model."
    #             )
    #             continue
    #         self.optim.step()
    #         self.scheduler.step()
    #         pbar.update(1)
    #         pbar.set_description(
    #             f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
    #
    #     pbar.close()
    #
    def _validate_epoch(self, epoch):
        """
        Defines the validation process for a single epoch.
@@ -221,19 +197,3 @@
            for data, target in self.dataloader_val:
                # Implement the model validation steps here
                pass
# # Example usage
# if __name__ == "__main__":
#     # Assuming the following objects have already been correctly created and initialized:
#     # model, optim, scheduler, dataloader_train, and dataloader_val.
#     trainer = Trainer(
#         max_epoch=10,
#         model=model,
#         optim=optim,
#         scheduler=scheduler,
#         dataloader_train=dataloader_train,
#         dataloader_val=dataloader_val,
#         output_dir='path_to_save_model',
#         resume='path_to_checkpoint_if_any'
#     )
#     trainer.run()