From b3fb4c0acd5f52a313f024b6f69b8f025c6eddfe Mon Sep 17 00:00:00 2001
From: ming030890 <67713085+ming030890@users.noreply.github.com>
Date: 星期二, 05 八月 2025 17:48:10 +0800
Subject: [PATCH] Allow one to set a custom progress callback (#2609)

---
 funasr/auto/auto_model.py |   26 ++++++++++++++++++++++----
 1 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 10d2ef6..a864dad 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -301,14 +301,27 @@
         res = self.model(*args, kwargs)
         return res
 
-    def generate(self, input, input_len=None, **cfg):
+    def generate(self, input, input_len=None, progress_callback=None, **cfg):
         if self.vad_model is None:
-            return self.inference(input, input_len=input_len, **cfg)
+            return self.inference(
+                input, input_len=input_len, progress_callback=progress_callback, **cfg
+            )
 
         else:
-            return self.inference_with_vad(input, input_len=input_len, **cfg)
+            return self.inference_with_vad(
+                input, input_len=input_len, progress_callback=progress_callback, **cfg
+            )
 
-    def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
+    def inference(
+        self,
+        input,
+        input_len=None,
+        model=None,
+        kwargs=None,
+        key=None,
+        progress_callback=None,
+        **cfg,
+    ):
         kwargs = self.kwargs if kwargs is None else kwargs
         if "cache" in kwargs:
             kwargs.pop("cache")
@@ -365,6 +378,11 @@
             if pbar:
                 pbar.update(end_idx - beg_idx)
                 pbar.set_description(description)
+            if progress_callback:
+                try:
+                    progress_callback(end_idx, num_samples)
+                except Exception as e:
+                    logging.error(f"progress_callback error: {e}")
             time_speech_total += batch_data_time
             time_escape_total += time_escape
 

--
Gitblit v1.9.1