From 824377d2aae11dc9ebbde871e3b23a0e0cadc7af Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 17 Apr 2024 16:59:29 +0800
Subject: [PATCH] Dev gzf exp (#1626)

---
 funasr/datasets/audio_datasets/index_ds.py |   10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 5396c8a..53419e8 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -76,7 +76,10 @@
     
     def __init__(self, path: str, **kwargs):
         super().__init__()
-        
+        self.max_source_length = kwargs.get("max_source_length", 2048)
+        self.min_source_length = kwargs.get("min_source_length", 0)
+        self.max_target_length = kwargs.get("max_target_length", 2048)
+        self.min_target_length = kwargs.get("min_target_length", 0)
         if isinstance(path, (list, tuple)): # wav.scp, text.txt/text.trans
             from funasr.datasets.audio_datasets.scp2jsonl import gen_jsonl_from_wav_text_list
             jsonl_outdir = os.path.dirname(path[0])
@@ -101,7 +104,10 @@
                     target_len = data.get("target_len", 0)
                     if "aishell" in source:
                         target = target.replace(" ", "")
-
+                    if source_len < self.min_source_length or source_len > self.max_source_length:
+                        continue
+                    if target_len < self.min_target_length or target_len > self.max_target_length:
+                        continue
                     contents_i = {"source": source,
                                  "prompt": prompt,
                                  "target": target,

--
Gitblit v1.9.1