# --- speaker (SPK) model setup ---------------------------------------------
# Optionally build a speaker-verification model plus its clustering backend.
# `spk_model` / `spk_kwargs` / `cb_kwargs` come from the caller's **kwargs;
# explicitly passed None values are normalized to empty dicts so the
# .get() chains below stay safe.
spk_model = kwargs.get("spk_model", None)
spk_kwargs = {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
cb_kwargs = (
    {} if spk_kwargs.get("cb_kwargs", {}) is None else spk_kwargs.get("cb_kwargs", {})
)
if spk_model is not None:
    logging.info("Building SPK model.")
    spk_kwargs["model"] = spk_model
    spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
    spk_kwargs["device"] = kwargs["device"]
    spk_model, spk_kwargs = self.build_model(**spk_kwargs)
    # FIX: the bare ClusterBackend() construction that preceded this line was
    # dead code — it was overwritten immediately. Keep only the construction
    # that honors the user-supplied cb_kwargs.
    self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"])
    spk_mode = kwargs.get("spk_mode", "punc_segment")
    if spk_mode not in ["default", "vad_segment", "punc_segment"]:
        logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
| | |
# --- tokenizer construction -------------------------------------------------
# NOTE(review): fragment — assumes an enclosing loop providing index `i`, and
# `tokenizer_conf`, `tokenizer_class`, `token_list`, `vocab_size` defined
# earlier in the method; confirm against the full file.
tokenizers_build = []
vocab_sizes = []
token_lists = []

### === only for kws ===
# Per-branch token lists / segmentation dicts used only by keyword spotting.
token_list_files = kwargs.get("token_lists", [])
seg_dicts = kwargs.get("seg_dicts", [])

### === only for kws ===
if len(token_list_files) > 1:
    # FIX: dropped the duplicated attribute-style assignment that preceded
    # each of these lines; item-style assignment works for both plain dicts
    # and OmegaConf DictConfig objects.
    tokenizer_conf["token_list"] = token_list_files[i]
if len(seg_dicts) > 1:
    tokenizer_conf["seg_dict"] = seg_dicts[i]
### === only for kws ===

tokenizer = tokenizer_class(**tokenizer_conf)

if token_list is not None:
    vocab_size = len(token_list)

# Fall back to the tokenizer's own vocabulary size when none was provided.
# FIX: the identical second copy of this check was a no-op and was removed.
if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
    vocab_size = tokenizer.get_vocab_size()
token_lists.append(token_list)
vocab_sizes.append(vocab_size)
| | | |
| | |
| | | if pbar: |
| | | # pbar.update(1) |
| | | pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") |
| | | torch.cuda.empty_cache() |
| | | |
| | | device = next(model.parameters()).device |
| | | if device.type == "cuda": |
| | | with torch.cuda.device(device): |
| | | torch.cuda.empty_cache() |
| | | return asr_result_list |
| | | |
| | | def inference_with_vad(self, input, input_len=None, **cfg): |