From cb8b09e085bdfb5599ccd1a862912bc0b7e4f41c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 20 二月 2024 16:42:09 +0800
Subject: [PATCH] update
---
funasr/train_utils/trainer.py | 8 ++++++--
funasr/datasets/audio_datasets/preprocessor.py | 7 ++++---
2 files changed, 10 insertions(+), 5 deletions(-)
diff --git a/funasr/datasets/audio_datasets/preprocessor.py b/funasr/datasets/audio_datasets/preprocessor.py
index c2e27bf..a3ba3a5 100644
--- a/funasr/datasets/audio_datasets/preprocessor.py
+++ b/funasr/datasets/audio_datasets/preprocessor.py
@@ -26,9 +26,10 @@
return waveform
speed = random.choice(self.speed_perturb)
if speed != 1.0:
- waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
- torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
- waveform = waveform.view(-1)
+ with torch.no_grad():
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ torch.tensor(waveform).view(1, -1), fs, [['speed', str(speed)], ['rate', str(fs)]])
+ waveform = waveform.view(-1)
return waveform
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 3cd61a1..cf997a4 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -273,8 +273,9 @@
speed_stats["total_time"] = total_time
- pbar.update(1)
+
if self.local_rank == 0:
+ pbar.update(1)
gpu_info = "GPU, memory: {:.3f} GB, " \
"{:.3f} GB, "\
"{:.3f} GB, "\
@@ -290,6 +291,7 @@
f"(loss: {loss.detach().cpu().item():.3f}), "
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
f"{gpu_info}"
+ f"rank: {self.local_rank}"
)
pbar.set_description(description)
if self.writer:
@@ -344,14 +346,16 @@
loss = loss
time4 = time.perf_counter()
- pbar.update(1)
+
if self.local_rank == 0:
+ pbar.update(1)
description = (
f"validation epoch: {epoch}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+ f"rank: {self.local_rank}"
)
pbar.set_description(description)
if self.writer:
--
Gitblit v1.9.1