kongdeqiang
5 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/auto/auto_model.py
@@ -182,7 +182,10 @@
        set_all_random_seed(kwargs.get("seed", 0))
        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
        if ((device =="cuda" and not torch.cuda.is_available())
            or (device == "xpu" and not torch.xpu.is_available())
            or (device == "mps" and not torch.backends.mps.is_available())
            or kwargs.get("ngpu", 1) == 0):
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device
@@ -298,14 +301,27 @@
        res = self.model(*args, kwargs)
        return res
    # Dispatches a call to one of two inference paths: `self.inference` when
    # `self.vad_model` is None, `self.inference_with_vad` otherwise. `input`,
    # `input_len` and any extra `**cfg` keywords are passed through unchanged.
    # NOTE(review): this span looks like a diff rendering with the +/- markers
    # stripped -- each old line is immediately followed by its replacement.
    # Presumably only the variants that carry `progress_callback` are current;
    # confirm against the real file before relying on this text as source.
    def generate(self, input, input_len=None, **cfg):
    def generate(self, input, input_len=None, progress_callback=None, **cfg):
        if self.vad_model is None:
            return self.inference(input, input_len=input_len, **cfg)
            # Variant that also forwards `progress_callback` to `inference`.
            return self.inference(
                input, input_len=input_len, progress_callback=progress_callback, **cfg
            )
        else:
            return self.inference_with_vad(input, input_len=input_len, **cfg)
            # Variant that also forwards `progress_callback` to `inference_with_vad`.
            return self.inference_with_vad(
                input, input_len=input_len, progress_callback=progress_callback, **cfg
            )
    def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
    def inference(
        self,
        input,
        input_len=None,
        model=None,
        kwargs=None,
        key=None,
        progress_callback=None,
        **cfg,
    ):
        kwargs = self.kwargs if kwargs is None else kwargs
        if "cache" in kwargs:
            kwargs.pop("cache")
@@ -362,6 +378,11 @@
            if pbar:
                pbar.update(end_idx - beg_idx)
                pbar.set_description(description)
            if progress_callback:
                try:
                    progress_callback(end_idx, num_samples)
                except Exception as e:
                    logging.error(f"progress_callback error: {e}")
            time_speech_total += batch_data_time
            time_escape_total += time_escape
@@ -549,41 +570,8 @@
            # speaker embedding cluster after resorted
            if self.spk_model is not None and kwargs.get("return_spk_res", True):
                # 1. 先检查时间戳
                has_timestamp = (
                    hasattr(self.model, "internal_punc") or
                    self.punc_model is not None or
                    "timestamp" in result
                )
                if not has_timestamp:
                    logging.error("Need timestamp support...")
                    return results_ret_list
                # 2. 初始化 punc_res
                punc_res = None
                # 3. 根据不同情况设置 punc_res
                if hasattr(self.model, "internal_punc"):
                    punc_res = [{
                        "text": result["text"],
                        "punc_array": result.get("punc_array", []),
                        "timestamp": result.get("timestamp", [])
                    }]
                elif self.punc_model is not None:
                    punc_res = self.inference(
                        result["text"],
                        model=self.punc_model,
                        kwargs=self.punc_kwargs,
                        **cfg
                    )
                else:
                    # 如果只有时间戳,创建一个基本的 punc_res
                    punc_res = [{
                        "text": result["text"],
                        "punc_array": [],  # 空的标点数组
                        "timestamp": result["timestamp"]
                    }]
                if raw_text is None:
                    logging.error("Missing punc_model, which is required by spk_model.")
                all_segments = sorted(all_segments, key=lambda x: x[0])
                spk_embedding = result["spk_embedding"]
                labels = self.cb_model(