From 4adb76a6edbca93aae7caa83382e764d7b058f07 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 08 May 2024 19:21:58 +0800
Subject: [PATCH] Dev gzf exp (#1707)

---
 funasr/auto/auto_model.py |   49 ++++++++++++++++++++++++++++---------------------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 32fd560..97eb325 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -26,6 +26,7 @@
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
 from funasr.utils import export_utils
+from funasr.utils import misc
 
 try:
     from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
@@ -35,14 +36,7 @@
 
 
 def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
-    """
-
-    :param input:
-    :param input_len:
-    :param data_type:
-    :param frontend:
-    :return:
-    """
+    """ """
     data_list = []
     key_list = []
     filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
@@ -73,7 +67,8 @@
                     key_list.append(key)
         else:
             if key is None:
-                key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                key = misc.extract_filename_without_extension(data_in)
             data_list = [data_in]
             key_list = [key]
     elif isinstance(data_in, (list, tuple)):
@@ -90,10 +85,14 @@
         else:
             # [audio sample point, fbank, text]
             data_list = data_in
-            key_list = [
-                "rand_key_" + "".join(random.choice(chars) for _ in range(13))
-                for _ in range(len(data_in))
-            ]
+            key_list = []
+            for data_i in data_in:
+                if isinstance(data_i, str) and os.path.exists(data_i):
+                    key = misc.extract_filename_without_extension(data_i)
+                else:
+                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                key_list.append(key)
+
     else:  # raw text; audio sample point, fbank; bytes
         if isinstance(data_in, bytes):  # audio bytes
             data_in = load_bytes(data_in)
@@ -108,6 +107,10 @@
 class AutoModel:
 
     def __init__(self, **kwargs):
+
+        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+        logging.basicConfig(level=log_level)
+
         if not kwargs.get("disable_log", True):
             tables.print()
 
@@ -361,7 +364,6 @@
             if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                 batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
 
-            batch_size_ms_cum = 0
             beg_idx = 0
             beg_asr_total = time.time()
             time_speech_total_per_sample = speech_lengths / 16000
@@ -370,19 +372,22 @@
             # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
 
             all_segments = []
+            max_len_in_batch = 0
+            end_idx = 1
             for j, _ in enumerate(range(0, n)):
                 # pbar_sample.update(1)
-                batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
+                sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
+                potential_batch_length = max(max_len_in_batch, sample_length) * (j + 1 - beg_idx)
+                # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
                 if (
                     j < n - 1
-                    and (batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0])
-                    < batch_size
-                    and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0])
-                    < batch_size_threshold_ms
+                    and sample_length < batch_size_threshold_ms
+                    and potential_batch_length < batch_size
                 ):
+                    max_len_in_batch = max(max_len_in_batch, sample_length)
+                    end_idx += 1
                     continue
-                batch_size_ms_cum = 0
-                end_idx = j + 1
+
                 speech_j, speech_lengths_j = slice_padding_audio_samples(
                     speech, speech_lengths, sorted_data[beg_idx:end_idx]
                 )
@@ -407,6 +412,8 @@
                         )
                         results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
                 beg_idx = end_idx
+                end_idx += 1
+                max_len_in_batch = sample_length
                 if len(results) < 1:
                     continue
                 results_sorted.extend(results)

--
Gitblit v1.9.1