From b3fb4c0acd5f52a313f024b6f69b8f025c6eddfe Mon Sep 17 00:00:00 2001
From: ming030890 <67713085+ming030890@users.noreply.github.com>
Date: 星期二, 05 八月 2025 17:48:10 +0800
Subject: [PATCH] Allow one to set a custom progress callback (#2609)

---
 tests/test_auto_model.py |   36 ++++++++++++++++++++++++++++++++++--
 1 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/tests/test_auto_model.py b/tests/test_auto_model.py
index 932376b..d17d9ab 100644
--- a/tests/test_auto_model.py
+++ b/tests/test_auto_model.py
@@ -22,7 +22,39 @@
         kwargs["spk_kwargs"] = {"cb_kwargs": {"merge_thr": merge_thr}}
         model = AutoModel(**kwargs)
         self.assertEqual(model.cb_model.model_config['merge_thr'], merge_thr)
-        # res = model.generate(input="/test.wav", 
+        # res = model.generate(input="/test.wav",
         #              batch_size_s=300)
+
+    def test_progress_callback_called(self):
+        class DummyModel:
+            def __init__(self):
+                self.param = torch.nn.Parameter(torch.zeros(1))
+
+            def parameters(self):
+                return iter([self.param])
+
+            def eval(self):
+                pass
+
+            def inference(self, data_in=None, **kwargs):
+                results = [{"text": str(d)} for d in data_in]
+                return results, {"batch_data_time": 1}
+
+        am = AutoModel.__new__(AutoModel)
+        am.model = DummyModel()
+        am.kwargs = {"batch_size": 2, "disable_pbar": True}
+
+        progress = []
+
+        res = AutoModel.inference(
+            am,
+            ["a", "b", "c"],
+            progress_callback=lambda idx, total: progress.append((idx, total)),
+        )
+
+        self.assertEqual(len(progress), 2)
+        self.assertEqual(progress, [(2, 3), (3, 3)])
+
+
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()

--
Gitblit v1.9.1