From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] auto_model: xpu/mps CPU fallback, ClusterBackend kwargs, progress callback
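
Changes to funasr/auto/auto_model.py:

- Allow ClusterBackend options to be passed through spk_kwargs["cb_kwargs"]
  instead of always constructing ClusterBackend() with defaults.
- Extend the CPU fallback in build_model() beyond CUDA: a request for an
  unavailable cuda, xpu, or mps device now falls back to CPU instead of
  failing later.
- Use dict-style access on tokenizer_conf ("token_list", "seg_dict") in the
  kws multi-tokenizer path, since tokenizer_conf is a plain dict.
- Add an optional progress_callback argument to generate()/inference(); it
  is called after each batch with the number of processed inputs and the
  total count. Callback exceptions are logged and swallowed so a faulty
  callback cannot abort inference.
- Scope torch.cuda.empty_cache() to the device the model actually lives on,
  and skip it entirely on non-CUDA devices.

A minimal usage sketch of the new callback (model name and audio path are
placeholders):

    from funasr import AutoModel

    def on_progress(done, total):
        print(f"processed {done}/{total} inputs")

    model = AutoModel(model="paraformer-zh")
    res = model.generate(input="audio.wav", progress_callback=on_progress)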
---
funasr/auto/auto_model.py | 55 ++++++++++++++++++++++++++++++++++++++++++++-----------
1 file changed, 44 insertions(+), 11 deletions(-)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 9f5f4fb..a864dad 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -147,13 +147,17 @@
# if spk_model is not None, build spk model else None
spk_model = kwargs.get("spk_model", None)
spk_kwargs = {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
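+ # Optional ClusterBackend configuration, supplied via spk_kwargs["cb_kwargs"]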
+ cb_kwargs = (
+ {} if spk_kwargs.get("cb_kwargs", {}) is None else spk_kwargs.get("cb_kwargs", {})
+ )
if spk_model is not None:
logging.info("Building SPK model.")
spk_kwargs["model"] = spk_model
spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
spk_kwargs["device"] = kwargs["device"]
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
- self.cb_model = ClusterBackend().to(kwargs["device"])
+ self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"])
spk_mode = kwargs.get("spk_mode", "punc_segment")
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
@@ -179,7 +183,11 @@
set_all_random_seed(kwargs.get("seed", 0))
device = kwargs.get("device", "cuda")
- if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
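+ # Fall back to CPU when the requested backend (cuda/xpu/mps) is unavailable, or when ngpu == 0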
+ if ((device == "cuda" and not torch.cuda.is_available())
+ or (device == "xpu" and not torch.xpu.is_available())
+ or (device == "mps" and not torch.backends.mps.is_available())
+ or kwargs.get("ngpu", 1) == 0):
device = "cpu"
kwargs["batch_size"] = 1
kwargs["device"] = device
@@ -199,6 +207,7 @@
tokenizers_build = []
vocab_sizes = []
token_lists = []
+
### === only for kws ===
token_list_files = kwargs.get("token_lists", [])
seg_dicts = kwargs.get("seg_dicts", [])
@@ -213,9 +222,9 @@
### === only for kws ===
if len(token_list_files) > 1:
- tokenizer_conf.token_list = token_list_files[i]
+ tokenizer_conf["token_list"] = token_list_files[i]
if len(seg_dicts) > 1:
- tokenizer_conf.seg_dict = seg_dicts[i]
+ tokenizer_conf["seg_dict"] = seg_dicts[i]
### === only for kws ===
tokenizer = tokenizer_class(**tokenizer_conf)
@@ -228,8 +237,8 @@
if token_list is not None:
vocab_size = len(token_list)
- if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
- vocab_size = tokenizer.get_vocab_size()
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+ vocab_size = tokenizer.get_vocab_size()
token_lists.append(token_list)
vocab_sizes.append(vocab_size)
@@ -294,14 +303,27 @@
res = self.model(*args, kwargs)
return res
- def generate(self, input, input_len=None, **cfg):
+ def generate(self, input, input_len=None, progress_callback=None, **cfg):
if self.vad_model is None:
- return self.inference(input, input_len=input_len, **cfg)
+ return self.inference(
+ input, input_len=input_len, progress_callback=progress_callback, **cfg
+ )
else:
- return self.inference_with_vad(input, input_len=input_len, **cfg)
+ return self.inference_with_vad(
+ input, input_len=input_len, progress_callback=progress_callback, **cfg
+ )
- def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
+ def inference(
+ self,
+ input,
+ input_len=None,
+ model=None,
+ kwargs=None,
+ key=None,
+ progress_callback=None,
+ **cfg,
+ ):
kwargs = self.kwargs if kwargs is None else kwargs
if "cache" in kwargs:
kwargs.pop("cache")
@@ -358,13 +380,24 @@
if pbar:
pbar.update(end_idx - beg_idx)
pbar.set_description(description)
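+ # Report batch progress to the caller; callback errors are logged, never raised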
+ if progress_callback:
+ try:
+ progress_callback(end_idx, num_samples)
+ except Exception as e:
+ logging.error(f"progress_callback error: {e}")
time_speech_total += batch_data_time
time_escape_total += time_escape
if pbar:
# pbar.update(1)
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
- torch.cuda.empty_cache()
+
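+ # Free cached CUDA memory on this model's device only; no-op for non-CUDA devices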
+ device = next(model.parameters()).device
+ if device.type == "cuda":
+ with torch.cuda.device(device):
+ torch.cuda.empty_cache()
return asr_result_list
def inference_with_vad(self, input, input_len=None, **cfg):
--
Gitblit v1.9.1