| | |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.utils.timestamp_tools import timestamp_sentence |
| | | from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk |
| | | from funasr.models.campplus.cluster_backend import ClusterBackend |
| | | try: |
| | | from funasr.models.campplus.cluster_backend import ClusterBackend |
| | | except: |
| | | print("If you want to use the speaker diarization, please `pip install hdbscan`") |
| | | |
| | | |
| | | def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): |
| | |
| | | class AutoModel: |
| | | |
| | | def __init__(self, **kwargs): |
| | | tables.print() |
| | | if not kwargs.get("disable_log", False): |
| | | tables.print() |
| | | |
| | | model, kwargs = self.build_model(**kwargs) |
| | | |
| | |
| | | if spk_mode not in ["default", "vad_segment", "punc_segment"]: |
| | | logging.error("spk_mode should be one of default, vad_segment and punc_segment.") |
| | | self.spk_mode = spk_mode |
| | | self.preset_spk_num = kwargs.get("preset_spk_num", None) |
| | | if self.preset_spk_num: |
| | | logging.warning("Using preset speaker number: {}".format(self.preset_spk_num)) |
| | | |
| | | self.kwargs = kwargs |
| | | self.model = model |
| | |
| | | self.spk_model = spk_model |
| | | self.spk_kwargs = spk_kwargs |
| | | self.model_path = kwargs.get("model_path") |
| | | |
| | | |
| | | |
| | | def build_model(self, **kwargs): |
| | | assert "model" in kwargs |
| | |
| | | set_all_random_seed(kwargs.get("seed", 0)) |
| | | |
| | | device = kwargs.get("device", "cuda") |
| | | if not torch.cuda.is_available() or kwargs.get("ngpu", 0) == 0: |
| | | if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0: |
| | | device = "cpu" |
| | | kwargs["batch_size"] = 1 |
| | | kwargs["device"] = device |
| | |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) |
| | | model.eval() |
| | | |
| | | model.to(device) |
| | | |
| | | # init_param |
| | |
| | | res = self.model(*args, kwargs) |
| | | return res |
| | | |
| | | |
| | | |
| | | def generate(self, input, input_len=None, **cfg): |
| | | if self.vad_model is None: |
| | | return self.inference(input, input_len=input_len, **cfg) |
| | |
| | | kwargs = self.kwargs if kwargs is None else kwargs |
| | | kwargs.update(cfg) |
| | | model = self.model if model is None else model |
| | | model.eval() |
| | | |
| | | batch_size = kwargs.get("batch_size", 1) |
| | | # if kwargs.get("device", "cpu") == "cpu": |
| | |
| | | data_batch = data_list[beg_idx:end_idx] |
| | | key_batch = key_list[beg_idx:end_idx] |
| | | batch = {"data_in": data_batch, "key": key_batch} |
| | | if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank |
| | | if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank |
| | | batch["data_in"] = data_batch[0] |
| | | batch["data_lengths"] = input_len |
| | | |
| | |
| | | result["text"] = punc_res[0]["text"] |
| | | |
| | | # speaker embedding cluster after resorted |
| | | if self.spk_model is not None: |
| | | if self.spk_model is not None and kwargs.get('return_spk_res', True): |
| | | all_segments = sorted(all_segments, key=lambda x: x[0]) |
| | | spk_embedding = result['spk_embedding'] |
| | | labels = self.cb_model(spk_embedding.cpu(), oracle_num=self.preset_spk_num) |
| | | del result['spk_embedding'] |
| | | labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None)) |
| | | # del result['spk_embedding'] |
| | | sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) |
| | | if self.spk_mode == 'vad_segment': # recover sentence_list |
| | | sentence_list = [] |
| | |
| | | result['timestamp'], \ |
| | | result['raw_text']) |
| | | result['sentence_info'] = sentence_list |
| | | del result['spk_embedding'] |
| | | |
| | | result["key"] = key |
| | | results_ret_list.append(result) |