From b3fb4c0acd5f52a313f024b6f69b8f025c6eddfe Mon Sep 17 00:00:00 2001
From: ming030890 <67713085+ming030890@users.noreply.github.com>
Date: Tue, 05 Aug 2025 17:48:10 +0800
Subject: [PATCH] Allow one to set a custom progress callback (#2609)
---
funasr/auto/auto_model.py | 26 +++++++++++--
tests/test_auto_model.py | 36 +++++++++++++++++-
2 files changed, 56 insertions(+), 6 deletions(-)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 10d2ef6..a864dad 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -301,14 +301,27 @@
res = self.model(*args, kwargs)
return res
- def generate(self, input, input_len=None, **cfg):
+ def generate(self, input, input_len=None, progress_callback=None, **cfg):
if self.vad_model is None:
- return self.inference(input, input_len=input_len, **cfg)
+ return self.inference(
+ input, input_len=input_len, progress_callback=progress_callback, **cfg
+ )
else:
- return self.inference_with_vad(input, input_len=input_len, **cfg)
+ return self.inference_with_vad(
+ input, input_len=input_len, progress_callback=progress_callback, **cfg
+ )
- def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
+ def inference(
+ self,
+ input,
+ input_len=None,
+ model=None,
+ kwargs=None,
+ key=None,
+ progress_callback=None,
+ **cfg,
+ ):
kwargs = self.kwargs if kwargs is None else kwargs
if "cache" in kwargs:
kwargs.pop("cache")
@@ -365,6 +378,11 @@
if pbar:
pbar.update(end_idx - beg_idx)
pbar.set_description(description)
+ if progress_callback:
+ try:
+ progress_callback(end_idx, num_samples)
+ except Exception as e:
+ logging.error(f"progress_callback error: {e}")
time_speech_total += batch_data_time
time_escape_total += time_escape
diff --git a/tests/test_auto_model.py b/tests/test_auto_model.py
index 932376b..d17d9ab 100644
--- a/tests/test_auto_model.py
+++ b/tests/test_auto_model.py
@@ -22,7 +22,39 @@
kwargs["spk_kwargs"] = {"cb_kwargs": {"merge_thr": merge_thr}}
model = AutoModel(**kwargs)
self.assertEqual(model.cb_model.model_config['merge_thr'], merge_thr)
- # res = model.generate(input="/test.wav",
+ # res = model.generate(input="/test.wav",
# batch_size_s=300)
+
+ def test_progress_callback_called(self):
+ class DummyModel:
+ def __init__(self):
+ self.param = torch.nn.Parameter(torch.zeros(1))
+
+ def parameters(self):
+ return iter([self.param])
+
+ def eval(self):
+ pass
+
+ def inference(self, data_in=None, **kwargs):
+ results = [{"text": str(d)} for d in data_in]
+ return results, {"batch_data_time": 1}
+
+ am = AutoModel.__new__(AutoModel)
+ am.model = DummyModel()
+ am.kwargs = {"batch_size": 2, "disable_pbar": True}
+
+ progress = []
+
+ res = AutoModel.inference(
+ am,
+ ["a", "b", "c"],
+ progress_callback=lambda idx, total: progress.append((idx, total)),
+ )
+
+ self.assertEqual(len(progress), 2)
+ self.assertEqual(progress, [(2, 3), (3, 3)])
+
+
if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
+ unittest.main()
--
Gitblit v1.9.1