zhifu gao
2024-01-17 9a9c3b75b5b3359701844a91a9fae6d2979866cd
Funasr1.0 (#1261)

* funasr1.0 finetune

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
10 files changed, 445 lines
examples/industrial_data_pretraining/paraformer/finetune.sh 4
examples/industrial_data_pretraining/seaco_paraformer/demo.py 2
funasr/auto/auto_model.py 34
funasr/bin/train.py 17
funasr/datasets/audio_datasets/index_ds.py 6
funasr/datasets/audio_datasets/samplers.py 3
funasr/models/paraformer/model.py 1
funasr/models/paraformer/template.yaml 1
funasr/train_utils/average_nbest_models.py 268
funasr/train_utils/trainer.py 109
examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -9,9 +9,11 @@
python funasr/bin/train.py \
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+model_revision="v2.0.2" \
+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
+valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
++dataset_conf.batch_size=2 \
++dataset_conf.batch_type="example" \
++train_conf.max_epoch=2 \
+output_dir="outputs/debug/ckpt/funasr2/exp2" \
+device="cpu" \
+debug="true"
examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -15,6 +15,6 @@
                  spk_model_revision="v2.0.2",
                  )
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                     hotword='达摩院 魔搭')
print(res)
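
For reference, a minimal end-to-end sketch of the hotword-biased call this demo exercises; the model ID and revision are assumptions based on FunASR's ModelScope naming and may need adjusting:

from funasr import AutoModel

# SeACo-Paraformer with hotword biasing; ID/revision are assumptions.
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.2")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                     hotword='达摩院 魔搭')
print(res)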
funasr/auto/auto_model.py
@@ -221,7 +221,8 @@
        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
        disable_pbar = kwargs.get("disable_pbar", False)
        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) if not disable_pbar else None
        time_speech_total = 0.0
        time_escape_total = 0.0
        for beg_idx in range(0, num_samples, batch_size):
@@ -239,8 +240,7 @@
            time2 = time.perf_counter()
            
            asr_result_list.extend(results)
-            pbar.update(1)
            # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
            batch_data_time = meta_data.get("batch_data_time", -1)
            time_escape = time2 - time1
@@ -252,12 +252,15 @@
            description = (
                f"{speed_stats}, "
            )
-            pbar.set_description(description)
+            if pbar:
+                pbar.update(1)
+                pbar.set_description(description)
            time_speech_total += batch_data_time
            time_escape_total += time_escape
-        pbar.update(1)
-        pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+        if pbar:
+            pbar.update(1)
+            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
        torch.cuda.empty_cache()
        return asr_result_list
    
@@ -309,8 +312,11 @@
            time_speech_total_per_sample = speech_lengths/16000
            time_speech_total_all_samples += time_speech_total_per_sample
            pbar_sample = tqdm(colour="blue", total=n + 1, dynamic_ncols=True)
            all_segments = []
            for j, _ in enumerate(range(0, n)):
                pbar_sample.update(1)
                batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n - 1 and (
                    batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
@@ -319,13 +325,14 @@
                batch_size_ms_cum = 0
                end_idx = j + 1
                speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])       
-                results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
+                results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
                if self.spk_model is not None:
                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                    for _b in range(len(speech_j)):
-                        vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
-                                        sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
+                        vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0,
+                                        sorted_data[beg_idx:end_idx][_b][0][1]/1000.0,
                                        speech_j[_b]]]
                        segments = sv_chunk(vad_segments)
                        all_segments.extend(segments)
@@ -338,12 +345,13 @@
                results_sorted.extend(results)
            pbar_total.update(1)
            end_asr_total = time.time()
            time_escape_total_per_sample = end_asr_total - beg_asr_total
            pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
            pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                                 f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                                 f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
            restored_data = [0] * n
            for j in range(n):
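
The auto_model.py changes implement one pattern: the outer VAD-segmented pass owns the visible progress bar and silences the inner per-batch bars by passing disable_pbar=True. A self-contained sketch of that guard, assuming nothing beyond tqdm:

from tqdm import tqdm

def inner(items, disable_pbar=False):
    # Bar is only created when not disabled; every use is guarded.
    pbar = tqdm(total=len(items), colour="blue", dynamic_ncols=True) if not disable_pbar else None
    out = []
    for x in items:
        out.append(x * 2)  # stand-in for per-batch inference
        if pbar:
            pbar.update(1)
    if pbar:
        pbar.close()
    return out

def outer(batches):
    pbar = tqdm(total=len(batches), colour="blue", dynamic_ncols=True)
    results = []
    for b in batches:
        results.extend(inner(b, disable_pbar=True))  # inner bars stay silent
        pbar.update(1)
    pbar.close()
    return results

outer([[1, 2], [3, 4]])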
funasr/bin/train.py
@@ -141,30 +141,37 @@
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
    # import pdb;
    # pdb.set_trace()
    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
                               **kwargs.get("dataset_conf"))
    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+    batch_sampler_val = None
+    if batch_sampler is not None:
+        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+        batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,
                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                                pin_memory=True)
    
+    dataloader_val = torch.utils.data.DataLoader(dataset_val,
+                                                collate_fn=dataset_val.collator,
+                                                batch_sampler=batch_sampler_val,
+                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
+                                                pin_memory=True)
    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_tr,
-        dataloader_val=None,
+        dataloader_val=dataloader_val,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
funasr/datasets/audio_datasets/index_ds.py
@@ -54,7 +54,11 @@
        return len(self.contents)
    
    def __getitem__(self, index):
-        return self.contents[index]
+        try:
+            data = self.contents[index]
+        except:
+            print(index)
+        return data
    
    def get_source_len(self, data_dict):
        return data_dict["source_len"]
funasr/datasets/audio_datasets/samplers.py
@@ -13,6 +13,7 @@
                 buffer_size: int = 30,
                 drop_last: bool = False,
                 shuffle: bool = True,
+                 is_training: bool = True,
                 **kwargs):
        
        self.drop_last = drop_last
@@ -24,7 +25,7 @@
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle
+        self.shuffle = shuffle and is_training
    
    def __len__(self):
        return self.total_samples
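
The is_training flag gates shuffling so a sampler built for validation iterates deterministically and losses are comparable across epochs. A minimal sketch of the same gate in isolation (toy class, not FunASR's sampler):

import numpy as np

class ToySampler:
    def __init__(self, n, shuffle=True, is_training=True, seed=0):
        self.shuffle_idx = np.arange(n)
        self.shuffle = shuffle and is_training  # the gate from the diff
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        order = self.rng.permutation(self.shuffle_idx) if self.shuffle else self.shuffle_idx
        return iter(order.tolist())

print(list(ToySampler(5)))                     # shuffled for training
print(list(ToySampler(5, is_training=False)))  # fixed order: [0, 1, 2, 3, 4]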
funasr/models/paraformer/model.py
@@ -164,6 +164,7 @@
        self.use_1st_decoder_loss = use_1st_decoder_loss
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
+        self.error_calculator = None
    
    def forward(
        self,
funasr/models/paraformer/template.yaml
@@ -95,6 +95,7 @@
      - acc
      - max
  keep_nbest_models: 10
+  avg_nbest_model: 5
  log_interval: 50
optim: adam
funasr/train_utils/average_nbest_models.py
@@ -9,117 +9,173 @@
import torch
from typing import Collection
import os
import torch
import re
from collections import OrderedDict
from functools import cmp_to_key
from funasr.train.reporter import Reporter
# @torch.no_grad()
# def average_nbest_models(
#     output_dir: Path,
#     best_model_criterion: Sequence[Sequence[str]],
#     nbest: Union[Collection[int], int],
#     suffix: Optional[str] = None,
#     oss_bucket=None,
#     pai_output_dir=None,
# ) -> None:
#     """Generate averaged model from n-best models
#
#     Args:
#         output_dir: The directory contains the model file for each epoch
#         reporter: Reporter instance
#         best_model_criterion: Give criterions to decide the best model.
#             e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
#         nbest: Number of best model files to be averaged
#         suffix: A suffix added to the averaged model file name
#     """
#     if isinstance(nbest, int):
#         nbests = [nbest]
#     else:
#         nbests = list(nbest)
#     if len(nbests) == 0:
#         warnings.warn("At least 1 nbest values are required")
#         nbests = [1]
#     if suffix is not None:
#         suffix = suffix + "."
#     else:
#         suffix = ""
#
#     # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
#     nbest_epochs = [
#         (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
#         for ph, k, m in best_model_criterion
#         if reporter.has(ph, k)
#     ]
#
#     _loaded = {}
#     for ph, cr, epoch_and_values in nbest_epochs:
#         _nbests = [i for i in nbests if i <= len(epoch_and_values)]
#         if len(_nbests) == 0:
#             _nbests = [1]
#
#         for n in _nbests:
#             if n == 0:
#                 continue
#             elif n == 1:
#                 # The averaged model is same as the best model
#                 e, _ = epoch_and_values[0]
#                 op = output_dir / f"{e}epoch.pb"
#                 sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
#                 if sym_op.is_symlink() or sym_op.exists():
#                     sym_op.unlink()
#                 sym_op.symlink_to(op.name)
#             else:
#                 op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
#                 logging.info(
#                     f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
#                 )
#
#                 avg = None
#                 # 2.a. Averaging model
#                 for e, _ in epoch_and_values[:n]:
#                     if e not in _loaded:
#                         if oss_bucket is None:
#                             _loaded[e] = torch.load(
#                                 output_dir / f"{e}epoch.pb",
#                                 map_location="cpu",
#                             )
#                         else:
#                             buffer = BytesIO(
#                                 oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
#                             _loaded[e] = torch.load(buffer)
#                     states = _loaded[e]
#
#                     if avg is None:
#                         avg = states
#                     else:
#                         # Accumulated
#                         for k in avg:
#                             avg[k] = avg[k] + states[k]
#                 for k in avg:
#                     if str(avg[k].dtype).startswith("torch.int"):
#                         # For int type, not averaged, but only accumulated.
#                         # e.g. BatchNorm.num_batches_tracked
#                         # (If there are any cases that requires averaging
#                         #  or the other reducing method, e.g. max/min, for integer type,
#                         #  please report.)
#                         pass
#                     else:
#                         avg[k] = avg[k] / n
#
#                 # 2.b. Save the ave model and create a symlink
#                 if oss_bucket is None:
#                     torch.save(avg, op)
#                 else:
#                     buffer = BytesIO()
#                     torch.save(avg, buffer)
#                     oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
#                                           buffer.getvalue())
#
#         # 3. *.*.ave.pb is a symlink to the max ave model
#         if oss_bucket is None:
#             op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
#             sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
#             if sym_op.is_symlink() or sym_op.exists():
#                 sym_op.unlink()
#             sym_op.symlink_to(op.name)
def _get_checkpoint_paths(output_dir: str, last_n: int=5):
    """
    Get the paths of the last 'last_n' checkpoints by parsing filenames
    in the output directory.
    """
    # List all files in the output directory
    files = os.listdir(output_dir)
    # Filter out checkpoint files and extract epoch numbers
    checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
    # Sort files by epoch number in descending order
    checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
    # Get the last 'last_n' checkpoint paths
    checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
    return checkpoint_paths
-@torch.no_grad()
-def average_nbest_models(
-    output_dir: Path,
-    reporter: Reporter,
-    best_model_criterion: Sequence[Sequence[str]],
-    nbest: Union[Collection[int], int],
-    suffix: Optional[str] = None,
-    oss_bucket=None,
-    pai_output_dir=None,
-) -> None:
-    """Generate averaged model from n-best models
-    Args:
-        output_dir: The directory contains the model file for each epoch
-        reporter: Reporter instance
-        best_model_criterion: Give criterions to decide the best model.
-            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
-        nbest: Number of best model files to be averaged
-        suffix: A suffix added to the averaged model file name
-    """
-    if isinstance(nbest, int):
-        nbests = [nbest]
-    else:
-        nbests = list(nbest)
-    if len(nbests) == 0:
-        warnings.warn("At least 1 nbest values are required")
-        nbests = [1]
-    if suffix is not None:
-        suffix = suffix + "."
-    else:
-        suffix = ""
-    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
-    nbest_epochs = [
-        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
-        for ph, k, m in best_model_criterion
-        if reporter.has(ph, k)
-    ]
-    _loaded = {}
-    for ph, cr, epoch_and_values in nbest_epochs:
-        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
-        if len(_nbests) == 0:
-            _nbests = [1]
-        for n in _nbests:
-            if n == 0:
-                continue
-            elif n == 1:
-                # The averaged model is same as the best model
-                e, _ = epoch_and_values[0]
-                op = output_dir / f"{e}epoch.pb"
-                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
-                if sym_op.is_symlink() or sym_op.exists():
-                    sym_op.unlink()
-                sym_op.symlink_to(op.name)
-            else:
-                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
-                logging.info(
-                    f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
-                )
-                avg = None
-                # 2.a. Averaging model
-                for e, _ in epoch_and_values[:n]:
-                    if e not in _loaded:
-                        if oss_bucket is None:
-                            _loaded[e] = torch.load(
-                                output_dir / f"{e}epoch.pb",
-                                map_location="cpu",
-                            )
-                        else:
-                            buffer = BytesIO(
-                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
-                            _loaded[e] = torch.load(buffer)
-                    states = _loaded[e]
-                    if avg is None:
-                        avg = states
-                    else:
-                        # Accumulated
-                        for k in avg:
-                            avg[k] = avg[k] + states[k]
-                for k in avg:
-                    if str(avg[k].dtype).startswith("torch.int"):
-                        # For int type, not averaged, but only accumulated.
-                        # e.g. BatchNorm.num_batches_tracked
-                        # (If there are any cases that requires averaging
-                        #  or the other reducing method, e.g. max/min, for integer type,
-                        #  please report.)
-                        pass
-                    else:
-                        avg[k] = avg[k] / n
-                # 2.b. Save the ave model and create a symlink
-                if oss_bucket is None:
-                    torch.save(avg, op)
-                else:
-                    buffer = BytesIO()
-                    torch.save(avg, buffer)
-                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
-                                          buffer.getvalue())
-        # 3. *.*.ave.pb is a symlink to the max ave model
-        if oss_bucket is None:
-            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
-            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
-            if sym_op.is_symlink() or sym_op.exists():
-                sym_op.unlink()
-            sym_op.symlink_to(op.name)
+def average_checkpoints(output_dir: str, last_n: int=5):
+    """
+    Average the last 'last_n' checkpoints' model state_dicts.
+    If a tensor is of type torch.int, perform sum instead of average.
+    """
+    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
+    state_dicts = []
+    # Load state_dicts from checkpoints
+    for path in checkpoint_paths:
+        if os.path.isfile(path):
+            state_dicts.append(torch.load(path, map_location='cpu')['state_dict'])
+        else:
+            print(f"Checkpoint file {path} not found.")
+            continue
+    # Check if we have any state_dicts to average
+    if not state_dicts:
+        raise RuntimeError("No checkpoints found for averaging.")
+    # Average or sum weights
+    avg_state_dict = OrderedDict()
+    for key in state_dicts[0].keys():
+        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
+        # Check the type of the tensor
+        if str(tensors[0].dtype).startswith("torch.int"):
+            # Perform sum for integer tensors
+            summed_tensor = sum(tensors)
+            avg_state_dict[key] = summed_tensor
+        else:
+            # Perform average for other types of tensors
+            stacked_tensors = torch.stack(tensors)
+            avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
+    torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}"))
+    return avg_state_dict
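
A hedged usage sketch for the new helper: it scans output_dir for model.pt.ep{N} files (the names written by the updated Trainer), averages the last five, and writes model.pt.avg5 alongside them. The output path here is the one from finetune.sh above.

import os
import torch
from funasr.train_utils.average_nbest_models import average_checkpoints

output_dir = "outputs/debug/ckpt/funasr2/exp2"
avg_state = average_checkpoints(output_dir, last_n=5)
# The averaged weights are also persisted next to the epoch checkpoints:
ckpt = torch.load(os.path.join(output_dir, "model.pt.avg5"), map_location="cpu")
# model.load_state_dict(ckpt["state_dict"])  # for an instantiated model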
funasr/train_utils/trainer.py
@@ -7,10 +7,11 @@
from contextlib import nullcontext
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
class Trainer:
    """
@@ -66,10 +67,9 @@
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = next(model.parameters()).device
+        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
        self.kwargs = kwargs
        
-        if self.resume:
-            self._resume_checkpoint(self.resume)
    
        try:
            rank = dist.get_rank()
@@ -102,9 +102,17 @@
        }
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
-        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
+        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
        torch.save(state, filename)
        print(f'Checkpoint saved to {filename}')
+        latest = Path(os.path.join(self.output_dir, f'model.pt'))
+        try:
+            latest.unlink()
+        except:
+            pass
+        latest.symlink_to(filename)
    
    def _resume_checkpoint(self, resume_path):
        """
@@ -114,29 +122,50 @@
        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
-        if os.path.isfile(resume_path):
-            checkpoint = torch.load(resume_path)
+        ckpt = os.path.join(resume_path, "model.pt")
+        if os.path.isfile(ckpt):
+            checkpoint = torch.load(ckpt)
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{resume_path}', starting from scratch")
            print(f"No checkpoint found at '{ckpt}', starting from scratch")
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        
    def run(self):
        """
        Starts the training process, iterating over epochs, training the model,
        and saving checkpoints at the end of each epoch.
        """
+        if self.resume:
+            self._resume_checkpoint(self.output_dir)
        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self._train_epoch(epoch)
-            # self._validate_epoch(epoch)
+            self._validate_epoch(epoch)
            if self.rank == 0:
                self._save_checkpoint(epoch)
-            self.scheduler.step()
            
            if self.use_ddp or self.use_fsdp:
                dist.barrier()
+            self.scheduler.step()
+        if self.rank == 0:
+            average_checkpoints(self.output_dir, self.avg_nbest_model)
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        self.writer.close()
    
    def _train_epoch(self, epoch):
        """
@@ -157,8 +186,7 @@
        for batch_idx, batch in enumerate(self.dataloader_train):
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1-time5:0.3f}"
            # import pdb;
            # pdb.set_trace()
            batch = to_device(batch, self.device)
            
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
@@ -211,13 +239,12 @@
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
    
                speed_stats["total_time"] = total_time
            # import pdb;
            # pdb.set_trace()
            pbar.update(1)
            if self.local_rank == 0:
                description = (
                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
                    f"Epoch: {epoch}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_train)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -248,6 +275,50 @@
        """
        self.model.eval()
        with torch.no_grad():
-            for data, target in self.dataloader_val:
-                # Implement the model validation steps here
-                pass
+            pbar = tqdm(colour="red", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_val),
+                        dynamic_ncols=True)
+            speed_stats = {}
+            time5 = time.perf_counter()
+            for batch_idx, batch in enumerate(self.dataloader_val):
+                time1 = time.perf_counter()
+                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+                batch = to_device(batch, self.device)
+                time2 = time.perf_counter()
+                retval = self.model(**batch)
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
+                # Scale the loss since we're not updating for every mini-batch
+                loss = loss
+                time4 = time.perf_counter()
+                pbar.update(1)
+                if self.local_rank == 0:
+                    description = (
+                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"step {batch_idx}/{len(self.dataloader_train)}, "
+                        f"{speed_stats}, "
+                        f"(loss: {loss.detach().cpu().item():.3f}), "
+                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                    )
+                    pbar.set_description(description)
+                    if self.writer:
+                        self.writer.add_scalar('Loss/val', loss.item(),
+                                               epoch*len(self.dataloader_train) + batch_idx)
+                        for key, var in stats.items():
+                            self.writer.add_scalar(f'{key}/val', var.item(),
+                                                   epoch * len(self.dataloader_train) + batch_idx)
+                        for key, var in speed_stats.items():
+                            self.writer.add_scalar(f'{key}/val', eval(var),
+                                                   epoch * len(self.dataloader_train) + batch_idx)
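
Taken together, the Trainer changes define a simple checkpoint contract: each epoch writes model.pt.ep{N}, a model.pt symlink is repointed at the newest file, and resuming reads the fixed path output_dir/model.pt. A standalone sketch of that scheme (filenames mirror the diff; the state payload is a stand-in):

import os
import torch

def save_checkpoint(output_dir, epoch, state):
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.join(output_dir, f"model.pt.ep{epoch}")
    torch.save(state, filename)
    latest = os.path.join(output_dir, "model.pt")
    try:
        os.unlink(latest)  # drop a stale symlink if one exists
    except FileNotFoundError:
        pass
    # Relative target keeps the link valid if the directory moves.
    os.symlink(os.path.basename(filename), latest)

save_checkpoint("/tmp/ckpt_demo", 1, {"epoch": 1, "state_dict": {}})
print(os.readlink(os.path.join("/tmp/ckpt_demo", "model.pt")))  # model.pt.ep1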