| | |
| | | self, |
| | | diar_train_config: Union[Path, str] = None, |
| | | diar_model_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | device: Union[str, torch.device] = "cpu", |
| | | batch_size: int = 1, |
| | | dtype: str = "float32", |
| | | streaming: bool = False, |
| | |
| | | # little-endian order: lower bit first |
| | | return (np.array(list(b)[::-1]) == '1').astype(dtype) |
| | | |
| | | return np.row_stack([int2vec(int(x), vec_dim) for x in seq]) |
| | | # process oov |
| | | seq = np.array([int(x) for x in seq]) |
| | | new_seq = [] |
| | | for i, x in enumerate(seq): |
| | | if x < 2 ** vec_dim: |
| | | new_seq.append(x) |
| | | else: |
| | | idx_list = np.where(seq < 2 ** vec_dim)[0] |
| | | idx = np.abs(idx_list - i).argmin() |
| | | new_seq.append(seq[idx_list[idx]]) |
| | | return np.row_stack([int2vec(x, vec_dim) for x in new_seq]) |
| | | |
| | | def post_processing(self, raw_logits: torch.Tensor, spk_num: int): |
| | | def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"): |
| | | logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T |
| | | # upsampling outputs to match inputs |
| | | ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio |
| | |
| | | ).squeeze(1).long() |
| | | logits_idx = logits_idx[0].tolist() |
| | | pse_labels = [self.token_list[x] for x in logits_idx] |
| | | if output_format == "pse_labels": |
| | | return pse_labels, None |
| | | |
| | | multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers |
| | | multi_labels = self.smooth_multi_labels(multi_labels) |
| | | if output_format == "binary_labels": |
| | | return multi_labels, None |
| | | |
| | | spk_list = ["spk{}".format(i + 1) for i in range(spk_num)] |
| | | spk_turns = self.calc_spk_turns(multi_labels, spk_list) |
| | | results = OrderedDict() |
| | |
| | | self, |
| | | speech: Union[torch.Tensor, np.ndarray], |
| | | profile: Union[torch.Tensor, np.ndarray], |
| | | output_format: str = "speaker_turn" |
| | | ): |
| | | """Inference |
| | | |
| | |
| | | batch = to_device(batch, device=self.device) |
| | | |
| | | logits = self.diar_model.prediction_forward(**batch) |
| | | results, pse_labels = self.post_processing(logits, profile.shape[1]) |
| | | results, pse_labels = self.post_processing(logits, profile.shape[1], output_format) |
| | | |
| | | return results, pse_labels |
| | | |
| | |
| | | pse_label_writer = open("{}/labels.txt".format(output_path), "w") |
| | | logging.info("Start to diarize...") |
| | | result_list = [] |
| | | for keys, batch in loader: |
| | | for idx, (keys, batch) in enumerate(loader): |
| | | assert isinstance(batch, dict), type(batch) |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | _bs = len(next(iter(batch.values()))) |
| | |
| | | pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels))) |
| | | pse_label_writer.flush() |
| | | |
| | | if idx % 100 == 0: |
| | | logging.info("Processing {:5d}: {}".format(idx, key)) |
| | | |
| | | if output_path is not None: |
| | | output_writer.close() |
| | | pse_label_writer.close() |