From 3eee773814c392e497557bbad501e0add4c8eca9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 09 六月 2024 02:11:42 +0800
Subject: [PATCH] fix bug

---
 funasr/datasets/audio_datasets/index_ds.py |    7 ++++++-
 1 files changed, 6 insertions(+), 1 deletions(-)

diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index d26124b..385218a 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -21,6 +21,7 @@
         self.min_source_length = kwargs.get("min_source_length", 0)
         self.max_target_length = kwargs.get("max_target_length", 2048)
         self.min_target_length = kwargs.get("min_target_length", 0)
+        self.max_token_length = kwargs.get("max_token_length", 2200)
 
         is_training = kwargs.get("is_training", True)
         if not (path.endswith(".jsonl") or path.endswith(".json")):
@@ -34,7 +35,7 @@
             with open(path, encoding="utf-8") as fin:
                 file_list_all = fin.readlines()
 
-                num_per_slice = (len(file_list_all) - 1) // data_split_num + 1
+                num_per_slice = (len(file_list_all) - 1) // data_split_num + 1  # 16
                 file_list = file_list_all[
                     data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
                 ]
@@ -103,6 +104,10 @@
                             or target_len > self.max_target_length
                         ):
                             continue
+
+                        if (source_len + target_len) > self.max_token_length:
+                            continue
+
                         contents_i = {
                             "source": source,
                             "prompt": prompt,

--
Gitblit v1.9.1