From 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 16 May 2023 15:07:20 +0800
Subject: [PATCH] Merge pull request #518 from alibaba-damo-academy/dev_wjm2

---
 egs/aishell2/transformer/utils/compute_fbank.py |   24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/egs/aishell2/transformer/utils/compute_fbank.py b/egs/aishell2/transformer/utils/compute_fbank.py
index d03b5a8..9c3904f 100755
--- a/egs/aishell2/transformer/utils/compute_fbank.py
+++ b/egs/aishell2/transformer/utils/compute_fbank.py
@@ -14,7 +14,8 @@
                   frame_shift=10,
                   dither=0.0,
                   resample_rate=16000,
-                  speed=1.0):
+                  speed=1.0,
+                  window_type="hamming"):
 
     waveform, sample_rate = torchaudio.load(wav_file)
     if resample_rate != sample_rate:
@@ -33,7 +34,7 @@
                       frame_shift=frame_shift,
                       dither=dither,
                       energy_floor=0.0,
-                      window_type='hamming',
+                      window_type=window_type,
                       sample_frequency=resample_rate)
 
     return mat.numpy()
@@ -68,6 +69,13 @@
         help="feature dims",
     )
     parser.add_argument(
+        "--max-lengths",
+        "-m",
+        default=1500,
+        type=int,
+        help="max frame numbers",
+    )
+    parser.add_argument(
         "--sample-frequency",
         "-s",
         default=16000,
@@ -96,6 +104,13 @@
         required=True,
         type=str,
         help="output dir",
+    )
+    parser.add_argument(
+        "--window-type",
+        default="hamming",
+        required=False,
+        type=str,
+        help="window type"
     )
     return parser
 
@@ -131,10 +146,13 @@
                     fbank = compute_fbank(wav_file,
                                           num_mel_bins=args.dims,
                                           resample_rate=args.sample_frequency,
-                                          speed=float(speed)
+                                          speed=float(speed),
+                                          window_type=args.window_type
                                           )
                     feats_dims = fbank.shape[1]
                     feats_lens = fbank.shape[0]
+                    if feats_lens >= args.max_lengths:
+                        continue
                     txt_lens = len(txt)
                     if speed == "1.0":
                         wav_id_sp = wav_id

--
Gitblit v1.9.1