self.preset_spk_num = kwargs.get("preset_spk_num", None)
if self.preset_spk_num:
    logging.warning("Using preset speaker number: {}".format(self.preset_spk_num))
logging.warning("Output may be verbose when using the speaker model...")

self.kwargs = kwargs
self.model = model

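# --- batched ASR over the VAD-sliced segments ---
# sorted_data pairs each VAD segment (start/end times in ms) with its original index;
# segments are assumed to be sorted by duration, so [beg_idx:end_idx] selects one batch.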
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
if self.spk_model is not None:
    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
    for _b in range(len(speech_j)):
        vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
                         sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
                         np.array(speech_j[_b])]]
        # chunk each VAD segment and extract a speaker embedding for clustering later
        segments = sv_chunk(vad_segments)
        all_segments.extend(segments)
        speech_b = [i[2] for i in segments]
        spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, disable_pbar=True, **cfg)
        results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
if len(results) < 1:
    continue
results_sorted.extend(results)
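# --- per-sample timing bookkeeping ---
# time_escape is the wall-clock decoding time and time_speech the audio duration,
# so their ratio is the real-time factor (RTF) reported below.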
end_asr_total = time.time()
time_escape_total_per_sample = end_asr_total - beg_asr_total
pbar_sample.update(1)
# NOTE: `pbar_sample` is assumed here to be the per-sample tqdm progress bar
pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                            f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                            f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")

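# --- restore the original segment order, then merge per-segment outputs ---
# sorted_data[j][1] is the position the j-th segment had in the original VAD output,
# so results_sorted can be scattered back into place before merging.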
restored_data = [0] * n
for j in range(n):
    index = sorted_data[j][1]
    restored_data[index] = results_sorted[j]

result = {}
for j in range(n):
    for k, v in restored_data[j].items():
        if k == 'spk_embedding':
            if k not in result:
                result[k] = restored_data[j][k]
            else:
                result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
        elif 'text' in k:
            # text-like fields ('text', 'raw_text', ...) are joined with spaces
            if k not in result:
                result[k] = restored_data[j][k]
            else:
                result[k] += " " + restored_data[j][k]
        else:
            # other fields (e.g. timestamps) accumulate directly (the full pipeline also shifts per-segment timestamps to absolute times)
            if k not in result:
                result[k] = restored_data[j][k]
            else:
                result[k] += restored_data[j][k]
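# --- punctuation restoration over the merged transcript ---
# punc_res[0]['punc_array'] is reused below to split the text into sentences.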
if self.punc_model is not None:
    self.punc_kwargs.update(cfg)
    punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
    import copy  # would normally live at the top of the module
    raw_text = copy.copy(result["text"])
    result["text"] = punc_res[0]["text"]

# speaker embedding clustering after the results have been re-sorted
if self.spk_model is not None:
    all_segments = sorted(all_segments, key=lambda x: x[0])
    spk_embedding = result['spk_embedding']
    labels = self.cb_model(spk_embedding.cpu(), oracle_num=self.preset_spk_num)
    del result['spk_embedding']
    sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
    if self.spk_mode == 'vad_segment':  # recover sentence_list from the VAD segments
        sentence_list = []
        for res, vadsegment in zip(restored_data, vadsegments):
            sentence_list.append({"start": vadsegment[0],
                                  "end": vadsegment[1],
                                  "sentence": res['raw_text'],
                                  "timestamp": res['timestamp']})
    elif self.spk_mode == 'punc_segment':  # re-split sentences on the predicted punctuation
        sentence_list = timestamp_sentence(punc_res[0]['punc_array'],
                                           result['timestamp'],
                                           result['raw_text'])
    distribute_spk(sentence_list, sv_output)
    result['sentence_info'] = sentence_list
elif kwargs.get("sentence_timestamp", False):
    sentence_list = timestamp_sentence(punc_res[0]['punc_array'],
                                       result['timestamp'],
                                       result['raw_text'])
    result['sentence_info'] = sentence_list

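# attach the sample key and collect this input's final result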
| | | result["key"] = key |
| | | results_ret_list.append(result) |
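
# A minimal usage sketch (an assumption, not taken from this excerpt) of how the
# VAD + ASR + punctuation + speaker pipeline above is typically driven; the model
# names are illustrative FunASR registry aliases:
#
#     from funasr import AutoModel
#     m = AutoModel(model="paraformer-zh", vad_model="fsmn-vad",
#                   punc_model="ct-punc", spk_model="cam++")
#     res = m.generate(input="example.wav", sentence_timestamp=True)
#     print(res[0]["sentence_info"])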