From b1c186fd00fef54bcad3aa1d073a1a313642d641 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 08 五月 2024 00:31:29 +0800
Subject: [PATCH] Dev gzf exp (#1700)

---
 funasr/auto/auto_model.py |   29 ++++++++++++++++-------------
 1 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 32fd560..577c328 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -26,6 +26,7 @@
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
 from funasr.utils import export_utils
+from funasr.utils import misc
 
 try:
     from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
@@ -35,14 +36,7 @@
 
 
 def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
-    """
-
-    :param input:
-    :param input_len:
-    :param data_type:
-    :param frontend:
-    :return:
-    """
+    """ """
     data_list = []
     key_list = []
     filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
@@ -73,7 +67,8 @@
                     key_list.append(key)
         else:
             if key is None:
-                key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                key = misc.extract_filename_without_extension(data_in)
             data_list = [data_in]
             key_list = [key]
     elif isinstance(data_in, (list, tuple)):
@@ -90,10 +85,14 @@
         else:
             # [audio sample point, fbank, text]
             data_list = data_in
-            key_list = [
-                "rand_key_" + "".join(random.choice(chars) for _ in range(13))
-                for _ in range(len(data_in))
-            ]
+            key_list = []
+            for data_i in data_in:
+                if isinstance(data_i, str) and os.path.exists(data_i):
+                    key = misc.extract_filename_without_extension(data_i)
+                else:
+                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+                key_list.append(key)
+
     else:  # raw text; audio sample point, fbank; bytes
         if isinstance(data_in, bytes):  # audio bytes
             data_in = load_bytes(data_in)
@@ -108,6 +107,10 @@
 class AutoModel:
 
     def __init__(self, **kwargs):
+
+        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+        logging.basicConfig(level=log_level)
+
         if not kwargs.get("disable_log", True):
             tables.print()
 

--
Gitblit v1.9.1