From cdf117b9746fdb72c6d0a2aa1ada4e1a131895ec Mon Sep 17 00:00:00 2001
From: aky15 <ankeyuthu@gmail.com>
Date: 星期二, 27 六月 2023 09:59:50 +0800
Subject: [PATCH] bug fix (#667)

---
 egs/wenetspeech/conformer/local/process_opus.py |   46 ++++++++++++++++++++++------------------------
 1 files changed, 22 insertions(+), 24 deletions(-)

diff --git a/egs/wenetspeech/conformer/local/process_opus.py b/egs/wenetspeech/conformer/local/process_opus.py
index 9f71eb1..581953f 100755
--- a/egs/wenetspeech/conformer/local/process_opus.py
+++ b/egs/wenetspeech/conformer/local/process_opus.py
@@ -16,14 +16,15 @@
 
 # usage: python3 process_opus.py wav.scp segments output_wav.scp
 
-from pydub import AudioSegment
-import sys
 import os
+import sys
+
+import torchaudio
 
 
 def read_file(wav_scp, segments):
     wav_scp_dict = {}
-    with open(wav_scp, 'r', encoding='UTF-8') as fin:
+    with open(wav_scp, "r", encoding="UTF-8") as fin:
         for line_str in fin:
             wav_id, path = line_str.strip().split()
             wav_scp_dict[wav_id] = path
@@ -32,7 +33,7 @@
     seg_path_list = []
     start_time_list = []
     end_time_list = []
-    with open(segments, 'r', encoding='UTF-8') as fin:
+    with open(segments, "r", encoding="UTF-8") as fin:
         for line_str in fin:
             arr = line_str.strip().split()
             assert len(arr) == 4
@@ -44,31 +45,28 @@
 
 
 # TODO(Qijie): Fix the process logic
-def output(output_wav_scp, utt_list, seg_path_list, start_time_list,
-           end_time_list):
+def output(output_wav_scp, utt_list, seg_path_list, start_time_list, end_time_list):
     num_utts = len(utt_list)
     step = int(num_utts * 0.01)
-    with open(output_wav_scp, 'w', encoding='UTF-8') as fout:
+    with open(output_wav_scp, "w", encoding="UTF-8") as fout:
         previous_wav_path = ""
         for i in range(num_utts):
             utt_id = utt_list[i]
             current_wav_path = seg_path_list[i]
-            output_dir = (os.path.dirname(current_wav_path)) \
-                .replace("audio", 'audio_seg')
-            seg_wav_path = os.path.join(output_dir, utt_id + '.wav')
+            output_dir = (os.path.dirname(current_wav_path)).replace(
+                "audio", "audio_seg"
+            )
+            seg_wav_path = os.path.join(output_dir, utt_id + ".wav")
 
-            # if not os.path.exists(output_dir):
-            #     os.makedirs(output_dir)
-
+            os.makedirs(output_dir, exist_ok=True)
             if current_wav_path != previous_wav_path:
-                source_wav = AudioSegment.from_file(current_wav_path)
+                waveform, sample_rate = torchaudio.load(current_wav_path)
             previous_wav_path = current_wav_path
 
-            start = int(start_time_list[i] * 1000)
-            end = int(end_time_list[i] * 1000)
-            target_audio = source_wav[start:end].set_frame_rate(16000) \
-                .set_sample_width(2)
-            target_audio.export(seg_wav_path, format="wav")
+            start = int(start_time_list[i] * sample_rate)
+            end = int(end_time_list[i] * sample_rate)
+            target_audio = waveform[:, start:end]
+            torchaudio.save(seg_wav_path, target_audio, sample_rate)
 
             fout.write("{} {}\n".format(utt_id, seg_wav_path))
             if i % step == 0:
@@ -80,11 +78,11 @@
     segments = sys.argv[2]
     output_wav_scp = sys.argv[3]
 
-    utt_list, seg_path_list, start_time_list, end_time_list \
-        = read_file(wav_scp, segments)
-    output(output_wav_scp, utt_list, seg_path_list, start_time_list,
-           end_time_list)
+    utt_list, seg_path_list, start_time_list, end_time_list = read_file(
+        wav_scp, segments
+    )
+    output(output_wav_scp, utt_list, seg_path_list, start_time_list, end_time_list)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

--
Gitblit v1.9.1