From 5023dd04224eddd4c9a047bd946695c3932743ae Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 15 Mar 2024 16:24:29 +0800
Subject: [PATCH] Dev gzf llm (#1503)

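This patch reworks LLM-ASR training around an explicit per-epoch loop instead of
the old monolithic Trainer.run(). In outline (simplified from
funasr/bin/train_llm.py below; argument lists abbreviated):

    trainer = Trainer(local_rank=local_rank, use_ddp=use_ddp, resume=True, ...)
    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        trainer.train_epoch(model=model, optim=optim, scheduler=scheduler, scaler=scaler,
                            dataloader_train=dataloader_tr, dataloader_val=dataloader_val,
                            epoch=epoch, writer=writer)
        trainer.validate_epoch(model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer)
        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
        scheduler.step()
    trainer.close(writer)

It also adds a jsonl2scp conversion script, registers the vicuna batch-sampler
factory, and fixes batch_size_s=0 handling in AutoModel.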
---
 funasr/train_utils/trainer_llm.py                                |  462 ++++++++++++++++++++++++++++++++++++++
 funasr/models/llm_asr_nar/model.py                               |   29 +
 funasr/datasets/audio_datasets/jsonl2scp.py                      |   62 +++++
 examples/industrial_data_pretraining/paraformer/demo.py          |    8 
 funasr/auto/auto_model.py                                        |    2 
 examples/industrial_data_pretraining/whisper/demo.py             |   10 
 examples/industrial_data_pretraining/whisper/demo_from_openai.py |    3 
 funasr/datasets/llm_datasets_vicuna/samplers.py                  |    4 
 funasr/bin/train_llm.py                                          |  140 ++++++-----
 9 files changed, 639 insertions(+), 81 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index 0265b12..651df1e 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -7,10 +7,10 @@
 
 model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", 
                   model_revision="v2.0.4",
-                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  vad_model_revision="v2.0.4",
-                  punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.4",
+                  # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  # vad_model_revision="v2.0.4",
+                  # punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  # punc_model_revision="v2.0.4",
                   # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                   # spk_model_revision="v2.0.2",
                   )
diff --git a/examples/industrial_data_pretraining/whisper/demo.py b/examples/industrial_data_pretraining/whisper/demo.py
index 01e125d..ddebbdf 100644
--- a/examples/industrial_data_pretraining/whisper/demo.py
+++ b/examples/industrial_data_pretraining/whisper/demo.py
@@ -8,8 +8,14 @@
 from funasr import AutoModel
 
 model = AutoModel(model="iic/Whisper-large-v3",
-                  model_revision="v2.0.4",
+                  model_revision="v2.0.5",
+                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   )
 
-res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", language=None)
+res = model.generate(
+	language=None,
+	task="transcribe",
+	batch_size_s=0,
+	input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+
 print(res)
diff --git a/examples/industrial_data_pretraining/whisper/demo_from_openai.py b/examples/industrial_data_pretraining/whisper/demo_from_openai.py
index 2ee8ad5..5cac06b 100644
--- a/examples/industrial_data_pretraining/whisper/demo_from_openai.py
+++ b/examples/industrial_data_pretraining/whisper/demo_from_openai.py
@@ -10,10 +10,11 @@
 # model = AutoModel(model="Whisper-small", hub="openai")
 # model = AutoModel(model="Whisper-medium", hub="openai")
 # model = AutoModel(model="Whisper-large-v2", hub="openai")
-model = AutoModel(model="Whisper-large-v3", hub="openai")
+model = AutoModel(model="Whisper-large-v3", hub="openai", vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",)
 
 res = model.generate(
 	language=None,
 	task="transcribe",
+	batch_size_s=0,
 	input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
 print(res)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 2df1910..8c847c5 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -291,7 +291,7 @@
         # step.2 compute asr model
         model = self.model
         deep_update(kwargs, cfg)
-        batch_size = int(kwargs.get("batch_size_s", 300))*1000
+        batch_size = max(int(kwargs.get("batch_size_s", 300))*1000, 1)
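+        # clamp to at least 1 ms so batch_size_s=0 degenerates to single-sample
+        # batches instead of a zero batch budget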
         batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
         kwargs["batch_size"] = batch_size
 
diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
index a33cd53..8742bf1 100644
--- a/funasr/bin/train_llm.py
+++ b/funasr/bin/train_llm.py
@@ -6,17 +6,22 @@
 import torch
 import hydra
 import logging
+import time
 import argparse
 from io import BytesIO
+
 import torch.distributed as dist
 from collections.abc import Sequence
 from omegaconf import DictConfig, OmegaConf
+from torch.cuda.amp import autocast, GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from funasr.train_utils.average_nbest_models import average_checkpoints
 
 from funasr.register import tables
 from funasr.optimizers import optim_classes
-from funasr.train_utils.trainer import Trainer
+from funasr.train_utils.trainer_llm import Trainer
 from funasr.schedulers import scheduler_classes
 from funasr.train_utils.initialize import initialize
 from funasr.download.download_from_hub import download_model
@@ -61,14 +66,9 @@
         dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
         torch.cuda.set_device(local_rank)
         
-    device = kwargs.get("device", "cpu")
+    device = kwargs.get("device", "cuda")
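+    # build the model on CPU first; it is moved to the target device after DDP/FSDP wrapping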
     kwargs["device"] = "cpu"
     model = AutoModel(**kwargs)
-    kwargs["device"] = device
-    model = model.model
-    tokenizer = kwargs["tokenizer"]
-    frontend = kwargs["frontend"]
-    
     
     
     # save config.yaml
@@ -77,35 +77,14 @@
         yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
         OmegaConf.save(config=kwargs, f=yaml_file)
         logging.info("config.yaml is saved to: %s", yaml_file)
-
-
     
-    
-
-    # init_param
-    init_param = kwargs.get("init_param", None)
-    if init_param is not None:
-        if not isinstance(init_param, (list, tuple)):
-            init_param = (init_param,)
-        logging.info("init_param is not None: %s", init_param)
-        for p in init_param:
-            if os.path.exists(p):
-                logging.info(f"Loading pretrained params from {p}")
-                load_pretrained_model(
-                    model=model,
-                    path=p,
-                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
-                    oss_bucket=kwargs.get("oss_bucket", None),
-                    scope_map=kwargs.get("scope_map", []),
-                    excludes=kwargs.get("excludes", None),
-                )
-            else:
-                logging.info(f"Checkpoint does not exist, init randomly: {p}")
-    elif kwargs.get("init", None):
-        initialize(model, kwargs.get("init", "kaiming_normal"))
-    else:
-        print("No initialize method")
-
+    # parse kwargs
+    kwargs = model.kwargs
+    kwargs["device"] = device
+    tokenizer = kwargs["tokenizer"]
+    frontend = kwargs["frontend"]
+    model = model.model
+    del kwargs["model"]
 
     # freeze_param
     freeze_param = kwargs.get("freeze_param", None)
@@ -129,7 +108,8 @@
         model = FSDP(model).cuda(local_rank)
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
-        
+
+    kwargs["device"] = next(model.parameters()).device
         
     # optim
     optim = kwargs.get("optim", "adam")
@@ -156,34 +136,68 @@
         batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
         batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
         batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
-    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
-                                                collate_fn=dataset_tr.collator,
-                                                batch_sampler=batch_sampler,
-                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
-                                                pin_memory=True)
+        
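+    # with the registered *_fn sampler factory, batch_sampler/batch_sampler_val are dicts of
+    # DataLoader kwargs (batch_sampler, num_workers, pin_memory); see llm_datasets_vicuna/samplers.py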
+    dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, **batch_sampler)
+    dataloader_val = torch.utils.data.DataLoader(dataset_val, collate_fn=dataset_val.collator, **batch_sampler_val)
+
+    trainer = Trainer(local_rank=local_rank,
+                      use_ddp=use_ddp,
+                      resume=kwargs.get("resume", True),
+                      device=kwargs["device"],
+                      **kwargs.get("train_conf"),
+                      )
+
+    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
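+    # FSDP shards gradients across ranks, so mixed precision needs ShardedGradScaler
+    # in place of the regular GradScaler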
+
+    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+
+    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
+    os.makedirs(tensorboard_dir, exist_ok=True)
+    try:
+        from tensorboardX import SummaryWriter
+        writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
+    except ImportError:
+        writer = None
     
-    dataloader_val = torch.utils.data.DataLoader(dataset_val,
-                                                collate_fn=dataset_val.collator,
-                                                batch_sampler=batch_sampler_val,
-                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
-                                                pin_memory=True)
-    trainer = Trainer(
-        model=model,
-        optim=optim,
-        scheduler=scheduler,
-        dataloader_train=dataloader_tr,
-        dataloader_val=dataloader_val,
-        local_rank=local_rank,
-        use_ddp=use_ddp,
-        use_fsdp=use_fsdp,
-        output_dir=kwargs.get("output_dir", "./exp"),
-        resume=kwargs.get("resume", True),
-        **kwargs.get("train_conf"),
-    )
-    trainer.run()
-    
-    if use_ddp or use_fsdp:
-        torch.distributed.destroy_process_group()
+    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
+        time1 = time.perf_counter()
+        trainer.train_epoch(
+            model=model,
+            optim=optim,
+            scheduler=scheduler,
+            scaler=scaler,
+            dataloader_train=dataloader_tr,
+            dataloader_val=dataloader_val,
+            epoch=epoch,
+            writer=writer,
+        )
+
+        trainer.validate_epoch(
+            model=model,
+            dataloader_val=dataloader_val,
+            epoch=epoch,
+            writer=writer
+        )
+
+        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+
+        scheduler.step()
+
+        time2 = time.perf_counter()
+        time_elapsed = (time2 - time1) / 3600.0
+        logging.info(
+            f"\nrank: {local_rank}, "
+            f"time elapsed this epoch: {time_elapsed:.3f} hours, "
+            f"estimated time to finish the remaining epochs "
+            f"(of {trainer.max_epoch}): {(trainer.max_epoch - epoch) * time_elapsed:.3f} hours\n")
+
+
+    if trainer.rank == 0:
+        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
+
+    trainer.close(writer)
+
 
     
 
diff --git a/funasr/datasets/audio_datasets/jsonl2scp.py b/funasr/datasets/audio_datasets/jsonl2scp.py
new file mode 100644
index 0000000..9a2b023
--- /dev/null
+++ b/funasr/datasets/audio_datasets/jsonl2scp.py
@@ -0,0 +1,62 @@
+import json
+
+import hydra
+from omegaconf import DictConfig, OmegaConf
+
+
+
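+# Each JSONL line is expected to provide at least "key" plus the two fields named in
+# data_type_list (default: "source" and "target"); lines may also carry "prompt",
+# "source_len" and "target_len", which are not needed for scp generation. An
+# illustrative (made-up) sample line:
+#   {"key": "utt1", "source": "/data/utt1.wav", "source_len": 100, "target": "text", "target_len": 4}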
+def gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file):
+
+    with open(wav_scp_file, "w") as wav_f, \
+            open(text_file, "w") as text_f, \
+            open(jsonl_file, encoding="utf-8") as fin:
+        for line in fin:
+            data = json.loads(line.strip())
+
+            source = data[data_type_list[0]]
+            target = data[data_type_list[1]]
+            if "aishell" in source:
+                # aishell targets are space-separated; strip the spaces
+                target = target.replace(" ", "")
+            key = data["key"]
+            wav_f.write(f"{key}\t{source}\n")
+            text_f.write(f"{key}\t{target}\n")
+    
+    
+                    
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+ 
+    kwargs = OmegaConf.to_container(cfg, resolve=True)
+
+    scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
+    if isinstance(scp_file_list, str):
+        # the CLI passes a JSON-style list string; parse it safely instead of eval()
+        scp_file_list = json.loads(scp_file_list)
+    data_type_list = kwargs.get("data_type_list", ("source", "target"))
+    jsonl_file = kwargs.get("jsonl_file_in", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
+    gen_scp_from_jsonl(jsonl_file, data_type_list, *scp_file_list)
+    
+
+"""
+python -m funasr.datasets.audio_datasets.jsonl2scp \
+++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_in=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
+"""
+
+if __name__ == "__main__":
+    main_hydra()
+
+    
\ No newline at end of file
diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
index fe840e2..c728d9c 100644
--- a/funasr/datasets/llm_datasets_vicuna/samplers.py
+++ b/funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -142,9 +142,9 @@
     def set_epoch(self, epoch):
         self.epoch = epoch
 
-
+@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler_fn")
 def CustomDistributedBatchSampler_fn(dataset, **kwargs):
-    dataloader_args = {"dataset": dataset}
+    dataloader_args = {}
     dataloader_args["batch_sampler"] = CustomDistributedBatchSampler(dataset, **kwargs)
     dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
     dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index a6096b2..06b2193 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -264,7 +264,7 @@
             audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                             data_type=kwargs.get("data_type", "sound"),
                                                             tokenizer=None)
-            if len(kwargs.get("data_type")) > 1:
+            if len(kwargs.get("data_type", [])) > 1:
                 audio_sample_list, text_token_int_list = audio_sample_list
                 text_token_int = text_token_int_list[0].replace(" ", "")
                 text_token_int = tokenizer.encode(text_token_int)
@@ -561,7 +561,7 @@
         audio_mask = kwargs.get("audio_mask", None)
         audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
         text_token_int = kwargs.get("text_token_int", None)
-        if audio_token_lengths is None:
+        if audio_token_lengths is None and text_token_int is not None:
             audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
         
         batch = {"speech": speech, "speech_lengths": speech_lengths}
@@ -572,7 +572,9 @@
                                                                                        mask=enc_mask,
                                                                                        target_label_length=audio_token_lengths,
                                                                                        )
-            loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
+            loss_pre = 0.0
+            if audio_token_lengths is not None:
+                loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
         
         return pre_acoustic_embeds, pre_token_length, loss_pre
     
@@ -603,10 +605,12 @@
             audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                             data_type=kwargs.get("data_type", "sound"),
                                                             tokenizer=None)
-            if len(kwargs.get("data_type")) > 1:
+            if len(kwargs.get("data_type", [])) > 1:
                 audio_sample_list, text_token_int_list = audio_sample_list
-                text_token_int = text_token_int_list[0].replace(" ", "")
+                text_token_int = text_token_int_list[0]
                 text_token_int = tokenizer.encode(text_token_int)
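+                # tokenizer.encode() may prepend a BOS token; drop it since these ids
+                # are used as mid-sequence targets rather than sequence starts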
+                if text_token_int[0] == tokenizer.bos_token_id:
+                    text_token_int = text_token_int[1:]
             else:
                 text_token_int = None
             time2 = time.perf_counter()
@@ -621,24 +625,30 @@
         speech_lengths = speech_lengths.to(device=kwargs["device"])
         
         # Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int)
+        res = self.encode(speech, speech_lengths, text_token_int=text_token_int)
+        encoder_out = res[0]
         
         # adaptor
         encoder_out = self.adaptor(encoder_out)
         
         prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
         prompt_ids = tokenizer.encode(prompt_pre)
+        if prompt_ids[0] == tokenizer.bos_token_id:
+            prompt_ids = prompt_ids[1:]
         prompt_length = len(prompt_ids)
         prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+        pad = torch.tensor([tokenizer.pad_token_id], dtype=torch.int64).to(kwargs["device"])
         
         if hasattr(self.llm.model, "embed_tokens"):
             inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+            pad = self.llm.model.embed_tokens(pad)
         elif hasattr(self.llm.model.model, "embed_tokens"):
             inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+            pad = self.llm.model.model.embed_tokens(pad)
         else:
             inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+            pad = self.llm.model.model.model.embed_tokens(pad)
         
-        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
+        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1)  # [prompt, audio, pad]
         attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
         
         # model_outputs = self.llm.generate(
@@ -662,8 +672,11 @@
         preds = torch.argmax(model_outputs.logits, -1)
         text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
         
-        text = text[0].split(': ')[-1]
+        text = text[0].split(':')[-1]
         text = text.strip()
+        if text.startswith("Please\n "):
+            text = text.replace("Please\n ", "")
+            text = text.strip()
         
         # preds = torch.argmax(model_outputs.logits, -1)
         
diff --git a/funasr/train_utils/trainer_llm.py b/funasr/train_utils/trainer_llm.py
new file mode 100644
index 0000000..6a3b83b
--- /dev/null
+++ b/funasr/train_utils/trainer_llm.py
@@ -0,0 +1,462 @@
+import os
+import time
+import torch
+import logging
+from tqdm import tqdm
+from datetime import datetime
+import torch.distributed as dist
+from torch.cuda.amp import autocast, GradScaler
+from contextlib import nullcontext, contextmanager
+from pathlib import Path
+
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.recursive_op import recursive_average
+from funasr.train_utils.average_nbest_models import average_checkpoints
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+@contextmanager
+def maybe_autocast(enabled):
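+    # enter torch.cuda.amp.autocast only when mixed precision is enabled; otherwise a no-op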
+    if enabled:
+        with autocast():
+            yield
+    else:
+        yield
+
+class Trainer:
+    """
+    A simple trainer for LLM-ASR models: per-epoch training and validation loops,
+    periodic checkpointing, and optional resumption from the latest checkpoint.
+    Unlike the generic Trainer, the model, optimizer, scheduler, and dataloaders are
+    passed into each method call rather than stored on the instance.
+
+    Attributes:
+        max_epoch (int): Maximum number of epochs for training.
+        output_dir (str): Directory where model checkpoints are saved.
+        resume (bool): Whether to resume from output_dir/model.pt.
+        use_ddp / use_fsdp (bool): Distributed-training modes.
+        use_fp16 (bool): Whether mixed-precision training is enabled.
+    """
+    
+    def __init__(self,
+                 local_rank,
+                 use_ddp: bool = False,
+                 use_fsdp: bool = False,
+                 use_fp16: bool = False,
+                 output_dir: str="./",
+                 **kwargs):
+        """
+        Initializes the Trainer with distributed-training flags and checkpoint settings.
+
+        Args:
+            local_rank (int): Local rank of this process in distributed training.
+            use_ddp (bool): Whether DistributedDataParallel is used.
+            use_fsdp (bool): Whether FullyShardedDataParallel is used.
+            use_fp16 (bool): Whether to train with automatic mixed precision.
+            output_dir (str): Directory where checkpoints are saved. Default './'.
+            **kwargs: Additional keyword arguments:
+                      max_epoch (int): Maximum number of training epochs. Default 100.
+                      resume (bool): Resume from output_dir/model.pt if present. Default True.
+                      accum_grad (int): Gradient-accumulation steps. Default 1.
+        """
+        
+        self.output_dir = output_dir
+        self.resume = kwargs.get('resume', True)
+        self.start_epoch = 0
+        self.max_epoch = kwargs.get('max_epoch', 100)
+        self.local_rank = local_rank
+        self.use_ddp = use_ddp
+        self.use_fsdp = use_fsdp
+        self.device = kwargs.get('device', "cuda")
+        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
+        # self.kwargs = kwargs
+        self.log_interval = kwargs.get("log_interval", 50)
+        self.batch_total = 0
+        self.use_fp16 = use_fp16
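+        # when True, torch.cuda.empty_cache() is called after every forward pass,
+        # trading some speed for a lower peak GPU memory footprint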
+        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
+        # scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+        # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
+        # self.scaler = scaler
+        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
+        self.accum_grad = kwargs.get("accum_grad", 1)
+        self.grad_clip = kwargs.get("grad_clip", 10.0)
+        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
+        self.validate_interval = kwargs.get("validate_interval", 5000)
+        
+    
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except Exception:
+            rank = 0
+            world_size = 1
+            logging.warning("distributed is not initialized, falling back to a single shard")
+        self.rank = rank
+        self.world_size = world_size
+        
+
+        
+    
+    def save_checkpoint(self, epoch,
+                        step=None,
+                        model=None,
+                        optim=None,
+                        scheduler=None,
+                        scaler=None,
+                        ):
+        """
+        Saves a checkpoint containing the model's state, the optimizer's state,
+        and the scheduler's state at the end of the given epoch. This method is
+        intended to be called at the end of each epoch to save the training progress.
+
+        Args:
+            epoch (int): The epoch number at which the checkpoint is being saved.
+            step (int, optional): Step index within the epoch; when given, the file is
+                named model.pt.ep{epoch}.{step}. The latest state is always also
+                written to model.pt.
+        """
+        if self.rank == 0:
+            state = {
+                'epoch': epoch,
+                'state_dict': model.state_dict(),
+                'optimizer': optim.state_dict(),
+                'scheduler': scheduler.state_dict(),
+            }
+            if scaler:
+                state["scaler_state"] = scaler.state_dict()
+            # Create output directory if it does not exist
+            os.makedirs(self.output_dir, exist_ok=True)
+            if step is None:
+                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+            else:
+                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
+            
+            torch.save(state, filename)
+            
+            logging.info(f"Checkpoint saved to {filename}")
+            latest = Path(os.path.join(self.output_dir, "model.pt"))
+            torch.save(state, latest)
+        
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+    
+    def resume_checkpoint(self,
+                          model=None,
+                          optim=None,
+                          scheduler=None,
+                          scaler=None,
+                          ):
+        """
+        Resumes training from output_dir/model.pt when self.resume is set.
+        Loads the model, optimizer, scheduler, and (if present) GradScaler states,
+        remapping any "module." key prefixes from DDP-wrapped checkpoints.
+        """
+        if self.resume:
+            ckpt = os.path.join(self.output_dir, "model.pt")
+            if os.path.isfile(ckpt):
+                checkpoint = torch.load(ckpt)
+                self.start_epoch = checkpoint['epoch'] + 1
+                # self.model.load_state_dict(checkpoint['state_dict'])
+                src_state = checkpoint['state_dict']
+                dst_state = model.state_dict()
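+                # checkpoints saved from a DDP-wrapped model prefix keys with "module.";
+                # remap so they also load into an unwrapped model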
+                for k in dst_state.keys():
+                    if not k.startswith("module.") and "module."+k in src_state.keys():
+                        k_ddp = "module."+k
+                    else:
+                        k_ddp = k
+                    if k_ddp in src_state.keys():
+                        dst_state[k] = src_state[k_ddp]
+                    else:
+                        logging.warning(f"Missing key in ckpt: model: {k}, ckpt: {k_ddp}")
+    
+                model.load_state_dict(dst_state)
+                optim.load_state_dict(checkpoint['optimizer'])
+                scheduler.load_state_dict(checkpoint['scheduler'])
+                if scaler is not None and 'scaler_state' in checkpoint:
+                    scaler.load_state_dict(checkpoint['scaler_state'])
+                logging.info(f"Checkpoint loaded successfully from '{ckpt}'")
+            else:
+                logging.info(f"No checkpoint found at '{ckpt}', starting training from scratch")
+    
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        
+    def train_epoch(self,
+                model=None,
+                optim=None,
+                scheduler=None,
+                scaler=None,
+                dataloader_train=None,
+                dataloader_val=None,
+                epoch=None,
+                writer=None,
+                    ):
+        """
+        Defines the training process for a single epoch with gradient accumulation.
+        Args:
+            epoch (int): The current epoch number.
+        """
+        model.train()
+
+        
+        # Set the number of steps for gradient accumulation
+        accum_grad = self.accum_grad
+        # Initialize the gradient accumulation
+        optim.zero_grad()
+        speed_stats = {}
+        time5 = time.perf_counter()
+        
+        for batch_idx, batch in enumerate(dataloader_train):
+            self.batch_total += 1
+            time1 = time.perf_counter()
+            speed_stats["data_load"] = f"{time1-time5:0.3f}"
+
+            batch = to_device(batch, self.device)
+            
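+            # during gradient accumulation, enter no_sync() to skip DDP/FSDP all-reduce;
+            # gradients are synchronized only on the step where the optimizer updates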
+            accumulating = (batch_idx + 1) % accum_grad != 0
+            my_context = model.no_sync if (self.use_ddp or self.use_fsdp) and accumulating else nullcontext
+            with my_context():
+                time2 = time.perf_counter()
+                with maybe_autocast(self.use_fp16):
+                    retval = model(**batch)
+                    
+                if self.disable_gpu_cache:
+                    torch.cuda.empty_cache()
+
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
+                # Scale the loss since we're not updating for every mini-batch
+                loss = loss / accum_grad
+                if self.use_fp16:
+                    scaler.scale(loss).backward()
+                else:
+                    loss.backward()
+                time4 = time.perf_counter()
+                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+            
+            # Perform an optimizer step only after accumulating enough gradients
+            if (batch_idx + 1) % accum_grad == 0:
+                # Perform gradient clipping if it is set
+                if self.grad_clip > 0:
+                    if self.use_fp16:
+                        # unscale before clipping so the threshold applies to true gradient values
+                        scaler.unscale_(optim)
+                    grad_norm = torch.nn.utils.clip_grad_norm_(
+                        model.parameters(),
+                        max_norm=self.grad_clip,
+                        norm_type=self.grad_clip_type,
+                    )
+                    if not torch.isfinite(grad_norm):
+                        logging.warning(
+                            f"The grad norm is {grad_norm}. Skipping updating the model."
+                        )
+                        if self.use_fp16:
+                            # let the scaler register the inf/nan and reset its per-step state
+                            scaler.update()
+                        optim.zero_grad()  # Reset gradients
+                        continue
+                
+                # Execute an optimization step (update model parameters)
+                if self.use_ddp or self.use_fsdp:
+                    dist.barrier()
+                if self.use_fp16:
+                    scaler.step(optim)
+                    scaler.update()
+                else:
+                    optim.step()
+                scheduler.step()
+                # Clear gradients for the next accumulation stage
+                optim.zero_grad(set_to_none=True)
+                total_time = f"{time.perf_counter() - time5:0.3f}"
+                time5 = time.perf_counter()
+                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
+    
+                speed_stats["total_time"] = total_time
+                lr = scheduler.get_last_lr()[0]
+
+                self.log(epoch, batch_idx,
+                         batch_num_epoch=len(dataloader_train),
+                         lr=lr,
+                         loss=loss.detach().cpu().item(),
+                         speed_stats=speed_stats,
+                         stats=stats,
+                         writer=writer,
+                         tag="train",
+                         )
+
+            if (batch_idx + 1) % self.validate_interval == 0:
+                self.validate_epoch(
+                    model=model,
+                    dataloader_val=dataloader_val,
+                    epoch=epoch,
+                    writer=writer
+                )
+
+            if (batch_idx + 1) % self.save_checkpoint_interval == 0:
+                # save_checkpoint writes only on rank 0 and ends with dist.barrier(),
+                # so every rank must enter it to avoid a deadlock
+                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx + 1)
+
+        
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        
+        
+
+    def validate_epoch(self,
+                       model=None,
+                       dataloader_val=None,
+                       epoch=None,
+                       writer=None,
+                       **kwargs,
+                       ):
+        """
+        Defines the validation process for a single epoch.
+        Should be implemented with the actual model validation steps.
+    
+        Args:
+            epoch (int): The current epoch number.
+        """
+        model.eval()
+        
+        with torch.no_grad():
+            
+            speed_stats = {}
+            time5 = time.perf_counter()
+            for batch_idx, batch in enumerate(dataloader_val):
+                time1 = time.perf_counter()
+                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+                batch = to_device(batch, self.device)
+                time2 = time.perf_counter()
+                retval = model(**batch)
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
+                time4 = time.perf_counter()
+
+                
+                self.log(epoch, batch_idx,
+                         batch_num_epoch=len(dataloader_val),
+                         lr=0.0,
+                         loss=loss.detach().cpu().item(),
+                         speed_stats=speed_stats,
+                         stats=stats,
+                         writer=writer,
+                         tag="val",
+                         )
+
+        model.train()
+        
+        
+    def log(self,
+            epoch=0,
+            batch_idx=0,
+            batch_num_epoch=-1,
+            lr=0.0,
+            loss=0.0,
+            speed_stats=None,
+            stats=None,
+            writer=None,
+            tag="train",
+            ):
+        
+        if (batch_idx + 1) % self.log_interval == 0:
+            
+            gpu_info = ""
+            if torch.cuda.is_available():
+                gpu_info = "GPU memory (allocated/max_allocated/reserved/max_reserved): " \
+                           "{:.3f}/{:.3f}/{:.3f}/{:.3f} GB".format(
+                               torch.cuda.memory_allocated() / 1024 ** 3,
+                               torch.cuda.max_memory_allocated() / 1024 ** 3,
+                               torch.cuda.memory_reserved() / 1024 ** 3,
+                               torch.cuda.max_memory_reserved() / 1024 ** 3,
+                           )
+            
+            time_now = datetime.now()
+            time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
+            description = (
+                f"{time_now}, "
+                f"rank: {self.local_rank}, "
+                f"epoch: {epoch}/{self.max_epoch}, "
+                f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
+                f"(loss: {loss:.3f}), "
+                f"(lr: {lr:.3e}), "
+                f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
+                f"{speed_stats}, "
+                f"{gpu_info}"
+            )
+            logging.info(description)
+            
+            if writer is not None:
+                writer.add_scalar(f'rank{self.local_rank}_Loss/{tag}', loss, self.batch_total)
+                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
+                for key, var in stats.items():
+                    writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
+                for key, var in speed_stats.items():
+                    writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', float(var), self.batch_total)
+        
+    def close(self, writer=None):
+        if writer is not None:
+            writer.close()
+    
+        if self.use_ddp or self.use_fsdp:
+            torch.distributed.destroy_process_group()
\ No newline at end of file

--
Gitblit v1.9.1