| | |
| | | # if spk_model is not None, build spk model else None |
| | | spk_model = kwargs.get("spk_model", None) |
| | | spk_kwargs = {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {}) |
| | | cb_kwargs = {} if spk_kwargs.get("cb_kwargs", {}) is None else spk_kwargs.get("cb_kwargs", {}) |
| | | if spk_model is not None: |
| | | logging.info("Building SPK model.") |
| | | spk_kwargs["model"] = spk_model |
| | | spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master") |
| | | spk_kwargs["device"] = kwargs["device"] |
| | | spk_model, spk_kwargs = self.build_model(**spk_kwargs) |
| | | self.cb_model = ClusterBackend().to(kwargs["device"]) |
| | | self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"]) |
| | | spk_mode = kwargs.get("spk_mode", "punc_segment") |
| | | if spk_mode not in ["default", "vad_segment", "punc_segment"]: |
| | | logging.error("spk_mode should be one of default, vad_segment and punc_segment.") |
| | |
| | | tokenizers_build = [] |
| | | vocab_sizes = [] |
| | | token_lists = [] |
| | | |
| | | ### === only for kws === |
| | | token_list_files = kwargs.get("token_lists", []) |
| | | seg_dicts = kwargs.get("seg_dicts", []) |
| | |
| | | |
| | | ### === only for kws === |
| | | if len(token_list_files) > 1: |
| | | tokenizer_conf.token_list = token_list_files[i] |
| | | tokenizer_conf["token_list"] = token_list_files[i] |
| | | if len(seg_dicts) > 1: |
| | | tokenizer_conf.seg_dict = seg_dicts[i] |
| | | tokenizer_conf["seg_dict"] = seg_dicts[i] |
| | | ### === only for kws === |
| | | |
| | | tokenizer = tokenizer_class(**tokenizer_conf) |