游雁
2024-02-20 cb8b09e085bdfb5599ccd1a862912bc0b7e4f41c
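update: wrap speed-perturbation sox effects in torch.no_grad(); call pbar.update(1) only on local rank 0; append rank to training/validation progress descriptions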
2个文件已修改
共 15 行变更 ■■■■■
已修改文件：
funasr/datasets/audio_datasets/preprocessor.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/preprocessor.py
@@ -26,9 +26,10 @@
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
            waveform = waveform.view(-1)
            with torch.no_grad():
                waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                    torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
                waveform = waveform.view(-1)
            
        return waveform
funasr/train_utils/trainer.py
@@ -273,8 +273,9 @@
                speed_stats["total_time"] = total_time
            pbar.update(1)
            if self.local_rank == 0:
                pbar.update(1)
                gpu_info = "GPU, memory: {:.3f} GB, " \
                           "{:.3f} GB, "\
                           "{:.3f} GB, "\
@@ -290,6 +291,7 @@
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                    f"{gpu_info}"
                    f"rank: {self.local_rank}"
                )
                pbar.set_description(description)
                if self.writer:
@@ -344,14 +346,16 @@
                loss = loss
                time4 = time.perf_counter()
                pbar.update(1)
                if self.local_rank == 0:
                    pbar.update(1)
                    description = (
                        f"validation epoch: {epoch}/{self.max_epoch}, "
                        f"step {batch_idx}/{len(self.dataloader_train)}, "
                        f"{speed_stats}, "
                        f"(loss: {loss.detach().cpu().item():.3f}), "
                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                        f"rank: {self.local_rank}"
                    )
                    pbar.set_description(description)
                    if self.writer: