ming030890
2025-08-05 b3fb4c0acd5f52a313f024b6f69b8f025c6eddfe
Allow one to set a custom progress callback (#2609)

* Allow one to set a custom progress callback

so that callers can show their own progress bar

* Uncomment an existing test

* restore indentation

---------

Co-authored-by: Tony Mak <tony@Tonys-MacBook-Air-1802.local>
2个文件已修改
62 ■■■■■ 已修改文件
funasr/auto/auto_model.py 26 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
tests/test_auto_model.py 36 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/auto/auto_model.py
@@ -301,14 +301,27 @@
        res = self.model(*args, kwargs)
        return res
    def generate(self, input, input_len=None, **cfg):
    def generate(self, input, input_len=None, progress_callback=None, **cfg):
        if self.vad_model is None:
            return self.inference(input, input_len=input_len, **cfg)
            return self.inference(
                input, input_len=input_len, progress_callback=progress_callback, **cfg
            )
        else:
            return self.inference_with_vad(input, input_len=input_len, **cfg)
            return self.inference_with_vad(
                input, input_len=input_len, progress_callback=progress_callback, **cfg
            )
    def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
    def inference(
        self,
        input,
        input_len=None,
        model=None,
        kwargs=None,
        key=None,
        progress_callback=None,
        **cfg,
    ):
        kwargs = self.kwargs if kwargs is None else kwargs
        if "cache" in kwargs:
            kwargs.pop("cache")
@@ -365,6 +378,11 @@
            if pbar:
                pbar.update(end_idx - beg_idx)
                pbar.set_description(description)
            if progress_callback:
                try:
                    progress_callback(end_idx, num_samples)
                except Exception as e:
                    logging.error(f"progress_callback error: {e}")
            time_speech_total += batch_data_time
            time_escape_total += time_escape
tests/test_auto_model.py
@@ -22,7 +22,39 @@
        kwargs["spk_kwargs"] = {"cb_kwargs": {"merge_thr": merge_thr}}
        model = AutoModel(**kwargs)
        self.assertEqual(model.cb_model.model_config['merge_thr'], merge_thr)
        # res = model.generate(input="/test.wav",
        # res = model.generate(input="/test.wav",
        #              batch_size_s=300)
    def test_progress_callback_called(self):
        class DummyModel:
            def __init__(self):
                self.param = torch.nn.Parameter(torch.zeros(1))
            def parameters(self):
                return iter([self.param])
            def eval(self):
                pass
            def inference(self, data_in=None, **kwargs):
                results = [{"text": str(d)} for d in data_in]
                return results, {"batch_data_time": 1}
        am = AutoModel.__new__(AutoModel)
        am.model = DummyModel()
        am.kwargs = {"batch_size": 2, "disable_pbar": True}
        progress = []
        res = AutoModel.inference(
            am,
            ["a", "b", "c"],
            progress_callback=lambda idx, total: progress.append((idx, total)),
        )
        self.assertEqual(len(progress), 2)
        self.assertEqual(progress, [(2, 3), (3, 3)])
if __name__ == '__main__':
    unittest.main()
    unittest.main()