From 46e5beede954594983ccf01344793a76271f028f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 07 May 2024 22:12:05 +0800
Subject: [PATCH] decoding key

---
 funasr/bin/inference.py             |    3 ---
 funasr/models/sense_voice/model.py  |    7 +++++--
 funasr/bin/export.py                |    3 ---
 funasr/models/sense_voice/search.py |    2 ++
 funasr/auto/auto_model.py           |    4 ++++
 5 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index e77f04f..577c328 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -107,6 +107,10 @@
 class AutoModel:
 
     def __init__(self, **kwargs):
+
+        log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+        logging.basicConfig(level=log_level)
+
         if not kwargs.get("disable_log", True):
             tables.print()
 
diff --git a/funasr/bin/export.py b/funasr/bin/export.py
index 6c9b49f..9d01401 100644
--- a/funasr/bin/export.py
+++ b/funasr/bin/export.py
@@ -17,9 +17,6 @@
             return cfg_item
 
     kwargs = to_plain_list(cfg)
-    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
-
-    logging.basicConfig(level=log_level)
 
     if kwargs.get("debug", False):
         import pdb
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 2a1b6aa..39ee5c0 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -16,9 +16,6 @@
             return cfg_item
 
     kwargs = to_plain_list(cfg)
-    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
-
-    logging.basicConfig(level=log_level)
 
     if kwargs.get("debug", False):
         import pdb
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index dcf18fd..0230638 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -803,7 +803,10 @@
                 tokenizer=tokenizer,
             )
 
-            if len(kwargs.get("data_type", [])) > 1:
+            if (
+                isinstance(kwargs.get("data_type", None), (list, tuple))
+                and len(kwargs.get("data_type", [])) > 1
+            ):
                 audio_sample_list, text_token_int_list = audio_sample_list
                 text_token_int = text_token_int_list[0]
             else:
@@ -857,7 +860,7 @@
             ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
                 None, :
             ]
-            ys_pad_lens = torch.tensor([len(text_token_int)], dtype=torch.int64).to(
+            ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
                 kwargs["device"]
             )[None, :]
             decoder_out = self.model.decoder(
diff --git a/funasr/models/sense_voice/search.py b/funasr/models/sense_voice/search.py
index 98d02db..694e569 100644
--- a/funasr/models/sense_voice/search.py
+++ b/funasr/models/sense_voice/search.py
@@ -370,6 +370,8 @@
             # post process of one iteration
             running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
             # end detection
+            # if len(ended_hyps) > 0:
+            #     print(f"ended_hyps: {ended_hyps}")
             if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                 logging.info(f"end detected at {i}")
                 break

--
Gitblit v1.9.1