From 9a9c3b75b5b3359701844a91a9fae6d2979866cd Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 17 一月 2024 18:28:28 +0800
Subject: [PATCH] Funasr1.0 (#1261)

---
 funasr/models/paraformer/model.py                             |    1 
 funasr/models/paraformer/template.yaml                        |    1 
 funasr/train_utils/trainer.py                                 |  109 +++++++++++--
 funasr/bin/train.py                                           |   17 +
 funasr/train_utils/average_nbest_models.py                    |  268 ++++++++++++++++++++-------------
 funasr/auto/auto_model.py                                     |   34 ++-
 examples/industrial_data_pretraining/paraformer/finetune.sh   |    4 
 funasr/datasets/audio_datasets/index_ds.py                    |    6 
 funasr/datasets/audio_datasets/samplers.py                    |    3 
 examples/industrial_data_pretraining/seaco_paraformer/demo.py |    2 
 10 files changed, 298 insertions(+), 147 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh
index 93cce73..7d89876 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -9,9 +9,11 @@
 python funasr/bin/train.py \
 +model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 +model_revision="v2.0.2" \
-+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
++train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
++valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
 ++dataset_conf.batch_size=2 \
 ++dataset_conf.batch_type="example" \
+++train_conf.max_epoch=2 \
 +output_dir="outputs/debug/ckpt/funasr2/exp2" \
 +device="cpu" \
 +debug="true"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 5f17252..19ad1c9 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -15,6 +15,6 @@
                   spk_model_revision="v2.0.2",
                   )
 
-res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                      hotword='杈炬懇闄� 榄旀惌')
 print(res)
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 740614c..bedc17d 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -221,7 +221,8 @@
         speed_stats = {}
         asr_result_list = []
         num_samples = len(data_list)
-        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
+        disable_pbar = kwargs.get("disable_pbar", False)
+        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) if not disable_pbar else None
         time_speech_total = 0.0
         time_escape_total = 0.0
         for beg_idx in range(0, num_samples, batch_size):
@@ -239,8 +240,7 @@
             time2 = time.perf_counter()
             
             asr_result_list.extend(results)
-            pbar.update(1)
-            
+
             # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
             batch_data_time = meta_data.get("batch_data_time", -1)
             time_escape = time2 - time1
@@ -252,12 +252,15 @@
             description = (
                 f"{speed_stats}, "
             )
-            pbar.set_description(description)
+            if pbar:
+                pbar.update(1)
+                pbar.set_description(description)
             time_speech_total += batch_data_time
             time_escape_total += time_escape
-            
-        pbar.update(1)
-        pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+
+        if pbar:
+            pbar.update(1)
+            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
         torch.cuda.empty_cache()
         return asr_result_list
     
@@ -309,8 +312,11 @@
             time_speech_total_per_sample = speech_lengths/16000
             time_speech_total_all_samples += time_speech_total_per_sample
 
+            pbar_sample = tqdm(colour="blue", total=n + 1, dynamic_ncols=True)
+
             all_segments = []
             for j, _ in enumerate(range(0, n)):
+                pbar_sample.update(1)
                 batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                 if j < n - 1 and (
                     batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
@@ -319,13 +325,14 @@
                 batch_size_ms_cum = 0
                 end_idx = j + 1
                 speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])       
-                results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
+                results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
                 if self.spk_model is not None:
-                    
+
+                  
                     # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                     for _b in range(len(speech_j)):
-                        vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
-                                        sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
+                        vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0,
+                                        sorted_data[beg_idx:end_idx][_b][0][1]/1000.0,
                                         speech_j[_b]]]
                         segments = sv_chunk(vad_segments)
                         all_segments.extend(segments)
@@ -338,12 +345,13 @@
                 results_sorted.extend(results)
 
 
-            pbar_total.update(1)
+            
             end_asr_total = time.time()
             time_escape_total_per_sample = end_asr_total - beg_asr_total
-            pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+            pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                                  f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                                  f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
+            
 
             restored_data = [0] * n
             for j in range(n):
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 7ae687e..0334006 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -141,30 +141,37 @@
     scheduler_class = scheduler_classes.get(scheduler)
     scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
 
-    # import pdb;
-    # pdb.set_trace()
+
     # dataset
     dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
     dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
+                               **kwargs.get("dataset_conf"))
 
     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+    batch_sampler_val = None
     if batch_sampler is not None:
+        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
         batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+        batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
     dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                 collate_fn=dataset_tr.collator,
                                                 batch_sampler=batch_sampler,
                                                 num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                                 pin_memory=True)
     
-
+    dataloader_val = torch.utils.data.DataLoader(dataset_val,
+                                                collate_fn=dataset_val.collator,
+                                                batch_sampler=batch_sampler_val,
+                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
+                                                pin_memory=True)
     trainer = Trainer(
         model=model,
         optim=optim,
         scheduler=scheduler,
         dataloader_train=dataloader_tr,
-        dataloader_val=None,
+        dataloader_val=dataloader_val,
         local_rank=local_rank,
         use_ddp=use_ddp,
         use_fsdp=use_fsdp,
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 8e5b05c..c94d209 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -54,7 +54,11 @@
         return len(self.contents)
     
     def __getitem__(self, index):
-        return self.contents[index]
+        try:
+            data = self.contents[index]
+        except:
+            print(index)
+        return data
     
     def get_source_len(self, data_dict):
         return data_dict["source_len"]
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index 4af35e9..e170c68 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -13,6 +13,7 @@
                  buffer_size: int = 30,
                  drop_last: bool = False,
                  shuffle: bool = True,
+                 is_training: bool = True,
                  **kwargs):
         
         self.drop_last = drop_last
@@ -24,7 +25,7 @@
         self.buffer_size = buffer_size
         self.max_token_length = kwargs.get("max_token_length", 5000)
         self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle
+        self.shuffle = shuffle and is_training
     
     def __len__(self):
         return self.total_samples
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index f92441d..9f3c3f3 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -164,6 +164,7 @@
         self.use_1st_decoder_loss = use_1st_decoder_loss
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
+        self.error_calculator = None
     
     def forward(
         self,
diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml
index 94eebf7..3972caa 100644
--- a/funasr/models/paraformer/template.yaml
+++ b/funasr/models/paraformer/template.yaml
@@ -95,6 +95,7 @@
       - acc
       - max
   keep_nbest_models: 10
+  avg_nbest_model: 5
   log_interval: 50
 
 optim: adam
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 96e1384..f117804 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -9,117 +9,173 @@
 
 import torch
 from typing import Collection
+import os
+import torch
+import re
+from collections import OrderedDict
+from functools import cmp_to_key
 
-from funasr.train.reporter import Reporter
 
+# @torch.no_grad()
+# def average_nbest_models(
+#     output_dir: Path,
+#     best_model_criterion: Sequence[Sequence[str]],
+#     nbest: Union[Collection[int], int],
+#     suffix: Optional[str] = None,
+#     oss_bucket=None,
+#     pai_output_dir=None,
+# ) -> None:
+#     """Generate averaged model from n-best models
+#
+#     Args:
+#         output_dir: The directory contains the model file for each epoch
+#         reporter: Reporter instance
+#         best_model_criterion: Give criterions to decide the best model.
+#             e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
+#         nbest: Number of best model files to be averaged
+#         suffix: A suffix added to the averaged model file name
+#     """
+#     if isinstance(nbest, int):
+#         nbests = [nbest]
+#     else:
+#         nbests = list(nbest)
+#     if len(nbests) == 0:
+#         warnings.warn("At least 1 nbest values are required")
+#         nbests = [1]
+#     if suffix is not None:
+#         suffix = suffix + "."
+#     else:
+#         suffix = ""
+#
+#     # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
+#     nbest_epochs = [
+#         (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
+#         for ph, k, m in best_model_criterion
+#         if reporter.has(ph, k)
+#     ]
+#
+#     _loaded = {}
+#     for ph, cr, epoch_and_values in nbest_epochs:
+#         _nbests = [i for i in nbests if i <= len(epoch_and_values)]
+#         if len(_nbests) == 0:
+#             _nbests = [1]
+#
+#         for n in _nbests:
+#             if n == 0:
+#                 continue
+#             elif n == 1:
+#                 # The averaged model is same as the best model
+#                 e, _ = epoch_and_values[0]
+#                 op = output_dir / f"{e}epoch.pb"
+#                 sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
+#                 if sym_op.is_symlink() or sym_op.exists():
+#                     sym_op.unlink()
+#                 sym_op.symlink_to(op.name)
+#             else:
+#                 op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
+#                 logging.info(
+#                     f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
+#                 )
+#
+#                 avg = None
+#                 # 2.a. Averaging model
+#                 for e, _ in epoch_and_values[:n]:
+#                     if e not in _loaded:
+#                         if oss_bucket is None:
+#                             _loaded[e] = torch.load(
+#                                 output_dir / f"{e}epoch.pb",
+#                                 map_location="cpu",
+#                             )
+#                         else:
+#                             buffer = BytesIO(
+#                                 oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
+#                             _loaded[e] = torch.load(buffer)
+#                     states = _loaded[e]
+#
+#                     if avg is None:
+#                         avg = states
+#                     else:
+#                         # Accumulated
+#                         for k in avg:
+#                             avg[k] = avg[k] + states[k]
+#                 for k in avg:
+#                     if str(avg[k].dtype).startswith("torch.int"):
+#                         # For int type, not averaged, but only accumulated.
+#                         # e.g. BatchNorm.num_batches_tracked
+#                         # (If there are any cases that requires averaging
+#                         #  or the other reducing method, e.g. max/min, for integer type,
+#                         #  please report.)
+#                         pass
+#                     else:
+#                         avg[k] = avg[k] / n
+#
+#                 # 2.b. Save the ave model and create a symlink
+#                 if oss_bucket is None:
+#                     torch.save(avg, op)
+#                 else:
+#                     buffer = BytesIO()
+#                     torch.save(avg, buffer)
+#                     oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
+#                                           buffer.getvalue())
+#
+#         # 3. *.*.ave.pb is a symlink to the max ave model
+#         if oss_bucket is None:
+#             op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
+#             sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
+#             if sym_op.is_symlink() or sym_op.exists():
+#                 sym_op.unlink()
+#             sym_op.symlink_to(op.name)
+
+
+def _get_checkpoint_paths(output_dir: str, last_n: int=5):
+    """
+    Get the paths of the last 'last_n' checkpoints by parsing filenames
+    in the output directory.
+    """
+    # List all files in the output directory
+    files = os.listdir(output_dir)
+    # Filter out checkpoint files and extract epoch numbers
+    checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
+    # Sort files by epoch number in descending order
+    checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
+    # Get the last 'last_n' checkpoint paths
+    checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
+    return checkpoint_paths
 
 @torch.no_grad()
-def average_nbest_models(
-    output_dir: Path,
-    reporter: Reporter,
-    best_model_criterion: Sequence[Sequence[str]],
-    nbest: Union[Collection[int], int],
-    suffix: Optional[str] = None,
-    oss_bucket=None,
-    pai_output_dir=None,
-) -> None:
-    """Generate averaged model from n-best models
-
-    Args:
-        output_dir: The directory contains the model file for each epoch
-        reporter: Reporter instance
-        best_model_criterion: Give criterions to decide the best model.
-            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
-        nbest: Number of best model files to be averaged
-        suffix: A suffix added to the averaged model file name
+def average_checkpoints(output_dir: str, last_n: int=5):
     """
-    if isinstance(nbest, int):
-        nbests = [nbest]
-    else:
-        nbests = list(nbest)
-    if len(nbests) == 0:
-        warnings.warn("At least 1 nbest values are required")
-        nbests = [1]
-    if suffix is not None:
-        suffix = suffix + "."
-    else:
-        suffix = ""
+    Average the last 'last_n' checkpoints' model state_dicts.
+    If a tensor is of type torch.int, perform sum instead of average.
+    """
+    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
+    state_dicts = []
 
-    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
-    nbest_epochs = [
-        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
-        for ph, k, m in best_model_criterion
-        if reporter.has(ph, k)
-    ]
+    # Load state_dicts from checkpoints
+    for path in checkpoint_paths:
+        if os.path.isfile(path):
+            state_dicts.append(torch.load(path, map_location='cpu')['state_dict'])
+        else:
+            print(f"Checkpoint file {path} not found.")
+            continue
 
-    _loaded = {}
-    for ph, cr, epoch_and_values in nbest_epochs:
-        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
-        if len(_nbests) == 0:
-            _nbests = [1]
+    # Check if we have any state_dicts to average
+    if not state_dicts:
+        raise RuntimeError("No checkpoints found for averaging.")
 
-        for n in _nbests:
-            if n == 0:
-                continue
-            elif n == 1:
-                # The averaged model is same as the best model
-                e, _ = epoch_and_values[0]
-                op = output_dir / f"{e}epoch.pb"
-                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
-                if sym_op.is_symlink() or sym_op.exists():
-                    sym_op.unlink()
-                sym_op.symlink_to(op.name)
-            else:
-                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
-                logging.info(
-                    f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
-                )
-
-                avg = None
-                # 2.a. Averaging model
-                for e, _ in epoch_and_values[:n]:
-                    if e not in _loaded:
-                        if oss_bucket is None:
-                            _loaded[e] = torch.load(
-                                output_dir / f"{e}epoch.pb",
-                                map_location="cpu",
-                            )
-                        else:
-                            buffer = BytesIO(
-                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
-                            _loaded[e] = torch.load(buffer)
-                    states = _loaded[e]
-
-                    if avg is None:
-                        avg = states
-                    else:
-                        # Accumulated
-                        for k in avg:
-                            avg[k] = avg[k] + states[k]
-                for k in avg:
-                    if str(avg[k].dtype).startswith("torch.int"):
-                        # For int type, not averaged, but only accumulated.
-                        # e.g. BatchNorm.num_batches_tracked
-                        # (If there are any cases that requires averaging
-                        #  or the other reducing method, e.g. max/min, for integer type,
-                        #  please report.)
-                        pass
-                    else:
-                        avg[k] = avg[k] / n
-
-                # 2.b. Save the ave model and create a symlink
-                if oss_bucket is None:
-                    torch.save(avg, op)
-                else:
-                    buffer = BytesIO()
-                    torch.save(avg, buffer)
-                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
-                                          buffer.getvalue())
-
-        # 3. *.*.ave.pb is a symlink to the max ave model
-        if oss_bucket is None:
-            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
-            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
-            if sym_op.is_symlink() or sym_op.exists():
-                sym_op.unlink()
-            sym_op.symlink_to(op.name)
+    # Average or sum weights
+    avg_state_dict = OrderedDict()
+    for key in state_dicts[0].keys():
+        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
+        # Check the type of the tensor
+        if str(tensors[0].dtype).startswith("torch.int"):
+            # Perform sum for integer tensors
+            summed_tensor = sum(tensors)
+            avg_state_dict[key] = summed_tensor
+        else:
+            # Perform average for other types of tensors
+            stacked_tensors = torch.stack(tensors)
+            avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
+    
+    torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}"))
+    return avg_state_dict
\ No newline at end of file
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index da346c3..91b30b0 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -7,10 +7,11 @@
 from contextlib import nullcontext
 # from torch.utils.tensorboard import SummaryWriter
 from tensorboardX import SummaryWriter
+from pathlib import Path
 
 from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
-
+from funasr.train_utils.average_nbest_models import average_checkpoints
 
 class Trainer:
     """
@@ -66,10 +67,9 @@
         self.use_ddp = use_ddp
         self.use_fsdp = use_fsdp
         self.device = next(model.parameters()).device
+        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
         self.kwargs = kwargs
         
-        if self.resume:
-            self._resume_checkpoint(self.resume)
     
         try:
             rank = dist.get_rank()
@@ -102,9 +102,17 @@
         }
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
-        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
+        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
         torch.save(state, filename)
+        
         print(f'Checkpoint saved to {filename}')
+        latest = Path(os.path.join(self.output_dir, f'model.pt'))
+        try:
+            latest.unlink()
+        except:
+            pass
+
+        latest.symlink_to(filename)
     
     def _resume_checkpoint(self, resume_path):
         """
@@ -114,29 +122,50 @@
         Args:
             resume_path (str): The file path to the checkpoint to resume from.
         """
-        if os.path.isfile(resume_path):
-            checkpoint = torch.load(resume_path)
+        ckpt = os.path.join(resume_path, "model.pt")
+        if os.path.isfile(ckpt):
+            checkpoint = torch.load(ckpt)
             self.start_epoch = checkpoint['epoch'] + 1
             self.model.load_state_dict(checkpoint['state_dict'])
             self.optim.load_state_dict(checkpoint['optimizer'])
             self.scheduler.load_state_dict(checkpoint['scheduler'])
-            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
+            print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
-            print(f"No checkpoint found at '{resume_path}', starting from scratch")
+            print(f"No checkpoint found at '{ckpt}', starting from scratch")
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         
     def run(self):
         """
         Starts the training process, iterating over epochs, training the model,
         and saving checkpoints at the end of each epoch.
         """
+        if self.resume:
+            self._resume_checkpoint(self.output_dir)
+        
         for epoch in range(self.start_epoch, self.max_epoch + 1):
+            
             self._train_epoch(epoch)
-            # self._validate_epoch(epoch)
+            
+            self._validate_epoch(epoch)
+            
             if self.rank == 0:
                 self._save_checkpoint(epoch)
-            self.scheduler.step()
             
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+            
+            self.scheduler.step()
+
+
+        if self.rank == 0:
+            average_checkpoints(self.output_dir, self.avg_nbest_model)
+            
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         self.writer.close()
+        
     
     def _train_epoch(self, epoch):
         """
@@ -157,8 +186,7 @@
         for batch_idx, batch in enumerate(self.dataloader_train):
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
-            # import pdb;
-            # pdb.set_trace()
+
             batch = to_device(batch, self.device)
             
             my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
@@ -211,13 +239,12 @@
                 speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
     
                 speed_stats["total_time"] = total_time
-            
-            # import pdb;
-            # pdb.set_trace()
+
+
             pbar.update(1)
             if self.local_rank == 0:
                 description = (
-                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
+                    f"Epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -248,6 +275,50 @@
         """
         self.model.eval()
         with torch.no_grad():
-            for data, target in self.dataloader_val:
-                # Implement the model validation steps here
-                pass
+            pbar = tqdm(colour="red", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_val),
+                        dynamic_ncols=True)
+            speed_stats = {}
+            time5 = time.perf_counter()
+            for batch_idx, batch in enumerate(self.dataloader_val):
+                time1 = time.perf_counter()
+                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+                batch = to_device(batch, self.device)
+                time2 = time.perf_counter()
+                retval = self.model(**batch)
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
+                # Scale the loss since we're not updating for every mini-batch
+                loss = loss
+                time4 = time.perf_counter()
+
+                pbar.update(1)
+                if self.local_rank == 0:
+                    description = (
+                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"step {batch_idx}/{len(self.dataloader_train)}, "
+                        f"{speed_stats}, "
+                        f"(loss: {loss.detach().cpu().item():.3f}), "
+                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                    )
+                    pbar.set_description(description)
+                    if self.writer:
+                        self.writer.add_scalar('Loss/val', loss.item(),
+                                               epoch*len(self.dataloader_train) + batch_idx)
+                        for key, var in stats.items():
+                            self.writer.add_scalar(f'{key}/val', var.item(),
+                                                   epoch * len(self.dataloader_train) + batch_idx)
+                        for key, var in speed_stats.items():
+                            self.writer.add_scalar(f'{key}/val', eval(var),
+                                                   epoch * len(self.dataloader_train) + batch_idx)
\ No newline at end of file

--
Gitblit v1.9.1