From b76af7be8cd7428f19ec0ba9a7fd811148fbc358 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 28 四月 2024 21:18:45 +0800
Subject: [PATCH] Merge branch 'dev_gzf_exp' of github.com:alibaba-damo-academy/FunASR into dev_gzf_exp merge

---
 funasr/datasets/sense_voice_datasets/datasets.py |  132 ++++++++++++++++++++------------
 funasr/train_utils/trainer.py                    |   42 +++++++--
 funasr/bin/train.py                              |   15 ++-
 docs/images/wechat.png                           |    0 
 4 files changed, 123 insertions(+), 66 deletions(-)

diff --git a/docs/images/wechat.png b/docs/images/wechat.png
index 6d19842..ac8fa38 100644
--- a/docs/images/wechat.png
+++ b/docs/images/wechat.png
Binary files differ
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 448e464..97516eb 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -13,7 +13,7 @@
 
 from contextlib import nullcontext
 import torch.distributed as dist
-from collections.abc import Sequence
+
 from omegaconf import DictConfig, OmegaConf
 from torch.cuda.amp import autocast, GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -99,7 +99,7 @@
     if freeze_param is not None:
         if "," in freeze_param:
             freeze_param = eval(freeze_param)
-        if not isinstance(freeze_param, Sequence):
+        if not isinstance(freeze_param, (list, tuple)):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
         for t in freeze_param:
@@ -193,7 +193,7 @@
     try:
         from tensorboardX import SummaryWriter
 
-        writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
+        writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
     except:
         writer = None
 
@@ -206,6 +206,7 @@
                 epoch, data_split_i=data_split_i, start_step=trainer.start_step
             )
             trainer.start_step = 0
+
             trainer.train_epoch(
                 model=model,
                 optim=optim,
@@ -222,11 +223,13 @@
             torch.cuda.empty_cache()
 
         trainer.validate_epoch(
-            model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
+            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
         )
         scheduler.step()
-
-        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+        trainer.step_in_epoch = 0
+        trainer.save_checkpoint(
+            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
+        )
 
         time2 = time.perf_counter()
         time_escaped = (time2 - time1) / 3600.0
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 226342c..1d269dd 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -51,6 +51,7 @@
         self.batch_size = kwargs.get("batch_size")
         self.batch_type = kwargs.get("batch_type")
         self.prompt_ids_len = 0
+        self.retry = kwargs.get("retry", 5)
 
     def get_source_len(self, index):
         item = self.index_ds[index]
@@ -64,59 +65,75 @@
         return len(self.index_ds)
 
     def __getitem__(self, index):
-        item = self.index_ds[index]
         # import pdb;
         # pdb.set_trace()
-        source = item["source"]
-        data_src = load_audio_text_image_video(source, fs=self.fs)
-        if self.preprocessor_speech:
-            data_src = self.preprocessor_speech(data_src, fs=self.fs)
-        speech, speech_lengths = extract_fbank(
-            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
-        )  # speech: [b, T, d]
 
-        if speech_lengths > self.batch_size:
-            return None
-        speech = speech.permute(0, 2, 1)
-        target = item["target"]
-        if self.preprocessor_text:
-            target = self.preprocessor_text(target)
+        output = None
+        for idx in range(self.retry):
+            if idx == 0:
+                index_cur = index
+            else:
+                if index <= self.retry:
+                    index_cur = index + idx
+                else:
+                    index_cur = torch.randint(0, index, ()).item()
 
-        task = item.get("prompt", "<|ASR|>")
-        text_language = item.get("text_language", "<|zh|>")
+            item = self.index_ds[index_cur]
 
-        prompt = f"{self.sos}{task}{text_language}"
-        prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
-        prompt_ids_len = len(prompt_ids) - 1  # [sos, task]
-        self.prompt_ids_len = prompt_ids_len
+            source = item["source"]
+            data_src = load_audio_text_image_video(source, fs=self.fs)
+            if self.preprocessor_speech:
+                data_src = self.preprocessor_speech(data_src, fs=self.fs)
+            speech, speech_lengths = extract_fbank(
+                data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
+            )  # speech: [b, T, d]
 
-        target_ids = self.tokenizer.encode(target, allowed_special="all")
-        target_ids_len = len(target_ids) + 1  # [lid, text]
-        if target_ids_len > 200:
-            return None
+            if speech_lengths > self.batch_size:
+                continue
+            speech = speech.permute(0, 2, 1)
+            target = item["target"]
+            if self.preprocessor_text:
+                target = self.preprocessor_text(target)
 
-        eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
+            task = item.get("prompt", "<|ASR|>")
+            text_language = item.get("text_language", "<|zh|>")
 
-        ids = prompt_ids + target_ids + eos
-        ids_lengths = len(ids)
+            prompt = f"{self.sos}{task}{text_language}"
+            prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
+            prompt_ids_len = len(prompt_ids) - 1  # [sos, task]
+            self.prompt_ids_len = prompt_ids_len
 
-        text = torch.tensor(ids, dtype=torch.int64)
-        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
+            target_ids = self.tokenizer.encode(target, allowed_special="all")
+            target_ids_len = len(target_ids) + 1  # [lid, text]
+            if target_ids_len > 200:
+                continue
 
-        target_mask = (
-            [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
-        )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
-        target_mask_lengths = len(target_mask)
-        target_mask = torch.tensor(target_mask, dtype=torch.float32)
-        target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)
-        return {
-            "speech": speech[0, :, :],
-            "speech_lengths": speech_lengths,
-            "text": text,
-            "text_lengths": text_lengths,
-            "target_mask": target_mask,
-            "target_mask_lengths": target_mask_lengths,
-        }
+            eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
+
+            ids = prompt_ids + target_ids + eos
+            ids_lengths = len(ids)
+
+            text = torch.tensor(ids, dtype=torch.int64)
+            text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
+
+            target_mask = (
+                [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
+            )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
+            target_mask_lengths = len(target_mask)
+            target_mask = torch.tensor(target_mask, dtype=torch.float32)
+            target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)
+
+            output = {
+                "speech": speech[0, :, :],
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                "target_mask": target_mask,
+                "target_mask_lengths": target_mask_lengths,
+            }
+            break
+
+        return output
 
     def collator(self, samples: list = None):
         outputs = {}
@@ -129,13 +146,30 @@
                 outputs[key].append(sample[key])
 
         if len(outputs) < 1:
-            logging.info(f"ERROR: data is empty!")
+            logging.error(f"ERROR: data is empty!")
             outputs = {
-                "speech": torch.rand((10, 128), dtype=torch.float32),
-                "speech_lengths": torch.tensor([10], dtype=torch.int32),
-                "text": torch.tensor([58836], dtype=torch.int32),
-                "text_lengths": torch.tensor([1], dtype=torch.int32),
-                "target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]]),
+                "speech": torch.rand((10, 128), dtype=torch.float32)[None, :, :],
+                "speech_lengths": torch.tensor(
+                    [
+                        10,
+                    ],
+                    dtype=torch.int32,
+                )[:, None],
+                "text": torch.tensor(
+                    [
+                        58836,
+                    ],
+                    dtype=torch.int32,
+                )[None, :],
+                "text_lengths": torch.tensor(
+                    [
+                        1,
+                    ],
+                    dtype=torch.int32,
+                )[:, None],
+                "target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]])[
+                    None, :
+                ],
             }
             return outputs
 
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 66f8778..e86420c 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -116,6 +116,7 @@
         self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
         self.start_data_split_i = 0
         self.start_step = 0
+        self.step_in_epoch = 0
         self.use_wandb = kwargs.get("use_wandb", False)
         if self.use_wandb:
             wandb.login(key=kwargs.get("wandb_token"))
@@ -137,6 +138,8 @@
         optim=None,
         scheduler=None,
         scaler=None,
+        step_in_epoch=None,
+        **kwargs,
     ):
         """
         Saves a checkpoint containing the model's state, the optimizer's state,
@@ -147,6 +150,7 @@
             epoch (int): The epoch number at which the checkpoint is being saved.
         """
 
+        step_in_epoch = None if step is None else step_in_epoch
         if self.rank == 0:
             logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
             # self.step_or_epoch += 1
@@ -161,7 +165,12 @@
                 "best_step_or_epoch": self.best_step_or_epoch,
                 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
                 "step": step,
+                "step_in_epoch": step_in_epoch,
+                "data_split_i": kwargs.get("data_split_i", 0),
+                "data_split_num": kwargs.get("data_split_num", 1),
+                "batch_total": self.batch_total,
             }
+            step = step_in_epoch
             if hasattr(model, "module"):
                 state["state_dict"] = model.module.state_dict()
 
@@ -195,7 +204,7 @@
                     )
                 else:
                     logging.info(
-                        f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}"
+                        f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             elif self.avg_keep_nbest_models_type == "loss":
                 if (
@@ -210,7 +219,7 @@
                     )
                 else:
                     logging.info(
-                        f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}"
+                        f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             else:
                 print("Undo")
@@ -251,7 +260,7 @@
             ckpt = os.path.join(self.output_dir, "model.pt")
             if os.path.isfile(ckpt):
                 checkpoint = torch.load(ckpt, map_location="cpu")
-                self.start_epoch = checkpoint["epoch"] + 1
+                self.start_epoch = checkpoint["epoch"]
                 # self.model.load_state_dict(checkpoint['state_dict'])
                 src_state = checkpoint["state_dict"]
                 dst_state = model.state_dict()
@@ -288,11 +297,15 @@
                     checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
                 )
                 self.start_data_split_i = (
-                    checkpoint["start_data_split_i"] if "start_data_split_i" in checkpoint else 0
+                    checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
                 )
                 self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
                 self.start_step = checkpoint["step"] if "step" in checkpoint else 0
                 self.start_step = 0 if self.start_step is None else self.start_step
+                self.step_in_epoch = (
+                    checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
+                )
+                self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
 
                 model.to(self.device)
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
@@ -321,7 +334,7 @@
         """
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
+        logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
         model.train()
 
         # Set the number of steps for gradient accumulation
@@ -341,6 +354,7 @@
                 if iterator_stop > 0:
                     break
             self.batch_total += 1
+            self.step_in_epoch += 1
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
 
@@ -443,6 +457,7 @@
                 self.log(
                     epoch,
                     batch_idx,
+                    step_in_epoch=self.step_in_epoch,
                     batch_num_epoch=batch_num_epoch,
                     lr=lr,
                     loss=loss.detach().cpu().item(),
@@ -454,16 +469,17 @@
                     data_split_num=kwargs.get("data_split_num", 1),
                 )
 
-            if (batch_idx + 1) % self.validate_interval == 0:
+            if self.step_in_epoch % self.validate_interval == 0:
                 self.validate_epoch(
                     model=model,
                     dataloader_val=dataloader_val,
                     epoch=epoch,
                     writer=writer,
                     step=batch_idx + 1,
+                    step_in_epoch=self.step_in_epoch,
                 )
 
-            if (batch_idx + 1) % self.save_checkpoint_interval == 0:
+            if self.step_in_epoch % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(
                     epoch,
                     model=model,
@@ -471,6 +487,9 @@
                     scheduler=scheduler,
                     scaler=scaler,
                     step=batch_idx + 1,
+                    step_in_epoch=self.step_in_epoch,
+                    data_split_i=kwargs.get("data_split_i", 0),
+                    data_split_num=kwargs.get("data_split_num", 1),
                 )
 
             time_beg = time.perf_counter()
@@ -500,7 +519,7 @@
         """
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
+        logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
         model.eval()
 
         with torch.no_grad():
@@ -578,10 +597,10 @@
                     iterator_stop.fill_(1)
                     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
 
-        if kwargs.get("step", None) is None:
+        if kwargs.get("step_in_epoch", None) is None:
             ckpt_name = f"model.pt.ep{epoch}"
         else:
-            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step")}'
+            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
         self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
         self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
         model.train()
@@ -594,6 +613,7 @@
         self,
         epoch=0,
         batch_idx=0,
+        step_in_epoch=0,
         batch_num_epoch=-1,
         lr=0.0,
         loss=0.0,
@@ -627,7 +647,7 @@
                 f"rank: {self.rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
                 f"data_slice: {data_split_i}/{data_split_num}, "
-                f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
+                f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
                 f"(loss_avg_rank: {loss:.3f}), "
                 f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
                 f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), "

--
Gitblit v1.9.1