| | |
| | | # Build the speaker (spk) model only when spk_model is provided; otherwise spk_model stays None |
| | | spk_model = kwargs.get("spk_model", None) |
| | | spk_kwargs = {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {}) |
| | | cb_kwargs = {} if spk_kwargs.get("cb_kwargs", {}) is None else spk_kwargs.get("cb_kwargs", {}) |
| | | if spk_model is not None: |
| | | logging.info("Building SPK model.") |
| | | spk_kwargs["model"] = spk_model |
| | | spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master") |
| | | spk_kwargs["device"] = kwargs["device"] |
| | | spk_model, spk_kwargs = self.build_model(**spk_kwargs) |
| | | self.cb_model = ClusterBackend().to(kwargs["device"]) |
| | | self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"]) |
| | | spk_mode = kwargs.get("spk_mode", "punc_segment") |
| | | if spk_mode not in ["default", "vad_segment", "punc_segment"]: |
| | | logging.error("spk_mode should be one of default, vad_segment and punc_segment.") |