kongdeqiang
5 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/auto/auto_model.py
@@ -301,14 +301,27 @@
        res = self.model(*args, kwargs)
        return res
    def generate(self, input, input_len=None, **cfg):
    def generate(self, input, input_len=None, progress_callback=None, **cfg):
        if self.vad_model is None:
            return self.inference(input, input_len=input_len, **cfg)
            return self.inference(
                input, input_len=input_len, progress_callback=progress_callback, **cfg
            )
        else:
            return self.inference_with_vad(input, input_len=input_len, **cfg)
            return self.inference_with_vad(
                input, input_len=input_len, progress_callback=progress_callback, **cfg
            )
    def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
    def inference(
        self,
        input,
        input_len=None,
        model=None,
        kwargs=None,
        key=None,
        progress_callback=None,
        **cfg,
    ):
        kwargs = self.kwargs if kwargs is None else kwargs
        if "cache" in kwargs:
            kwargs.pop("cache")
@@ -365,6 +378,11 @@
            if pbar:
                pbar.update(end_idx - beg_idx)
                pbar.set_description(description)
            if progress_callback:
                try:
                    progress_callback(end_idx, num_samples)
                except Exception as e:
                    logging.error(f"progress_callback error: {e}")
            time_speech_total += batch_data_time
            time_escape_total += time_escape