Merge pull request #250 from alibaba-damo-academy/dev_dzh
Dev dzh
| | |
| | | self, |
| | | diar_train_config: Union[Path, str] = None, |
| | | diar_model_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | device: Union[str, torch.device] = "cpu", |
| | | batch_size: int = 1, |
| | | dtype: str = "float32", |
| | | streaming: bool = False, |
| | |
| | | # little-endian order: lower bit first |
| | | return (np.array(list(b)[::-1]) == '1').astype(dtype) |
| | | |
| | | return np.row_stack([int2vec(int(x), vec_dim) for x in seq]) |
| | | # process oov |
| | | seq = np.array([int(x) for x in seq]) |
| | | new_seq = [] |
| | | for i, x in enumerate(seq): |
| | | if x < 2 ** vec_dim: |
| | | new_seq.append(x) |
| | | else: |
| | | idx_list = np.where(seq < 2 ** vec_dim)[0] |
| | | idx = np.abs(idx_list - i).argmin() |
| | | new_seq.append(seq[idx_list[idx]]) |
| | | return np.row_stack([int2vec(x, vec_dim) for x in new_seq]) |
| | | |
| | | def post_processing(self, raw_logits: torch.Tensor, spk_num: int): |
| | | def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"): |
| | | logits_idx = raw_logits.argmax(-1) # B, T, vocab_size -> B, T |
| | | # upsampling outputs to match inputs |
| | | ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio |
| | |
| | | ).squeeze(1).long() |
| | | logits_idx = logits_idx[0].tolist() |
| | | pse_labels = [self.token_list[x] for x in logits_idx] |
| | | if output_format == "pse_labels": |
| | | return pse_labels, None |
| | | |
| | | multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num] # remove padding speakers |
| | | multi_labels = self.smooth_multi_labels(multi_labels) |
| | | if output_format == "binary_labels": |
| | | return multi_labels, None |
| | | |
| | | spk_list = ["spk{}".format(i + 1) for i in range(spk_num)] |
| | | spk_turns = self.calc_spk_turns(multi_labels, spk_list) |
| | | results = OrderedDict() |
| | |
| | | self, |
| | | speech: Union[torch.Tensor, np.ndarray], |
| | | profile: Union[torch.Tensor, np.ndarray], |
| | | output_format: str = "speaker_turn" |
| | | ): |
| | | """Inference |
| | | |
| | |
| | | batch = to_device(batch, device=self.device) |
| | | |
| | | logits = self.diar_model.prediction_forward(**batch) |
| | | results, pse_labels = self.post_processing(logits, profile.shape[1]) |
| | | results, pse_labels = self.post_processing(logits, profile.shape[1], output_format) |
| | | |
| | | return results, pse_labels |
| | | |
| | |
| | | pse_label_writer = open("{}/labels.txt".format(output_path), "w") |
| | | logging.info("Start to diarize...") |
| | | result_list = [] |
| | | for keys, batch in loader: |
| | | for idx, (keys, batch) in enumerate(loader): |
| | | assert isinstance(batch, dict), type(batch) |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | _bs = len(next(iter(batch.values()))) |
| | |
| | | pse_label_writer.write("{} {}\n".format(key, " ".join(pse_labels))) |
| | | pse_label_writer.flush() |
| | | |
| | | if idx % 100 == 0: |
| | | logging.info("Processing {:5d}: {}".format(idx, key)) |
| | | |
| | | if output_path is not None: |
| | | output_writer.close() |
| | | pse_label_writer.close() |
| | |
| | | from typing import Iterator |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import List |
| | | |
| | | import kaldiio |
| | | import numpy as np |
| | |
| | | non_iterable_list = [] |
| | | self.path_name_type_list = [] |
| | | |
| | | if not isinstance(path_name_type_list[0], Tuple): |
| | | if not isinstance(path_name_type_list[0], (Tuple, List)): |
| | | path = path_name_type_list[0] |
| | | name = path_name_type_list[1] |
| | | _type = path_name_type_list[2] |
| | |
| | | normalize_speech_speaker: bool = False, |
| | | ignore_id: int = -1, |
| | | speaker_discrimination_loss_weight: float = 1.0, |
| | | inter_score_loss_weight: float = 0.0 |
| | | inter_score_loss_weight: float = 0.0, |
| | | inputs_type: str = "raw", |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | |
| | | ) |
| | | self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss) |
| | | self.pse_embedding = self.generate_pse_embedding() |
| | | # self.register_buffer("pse_embedding", pse_embedding) |
| | | self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float() |
| | | # self.register_buffer("power_weight", power_weight) |
| | | self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int() |
| | | # self.register_buffer("int_token_arr", int_token_arr) |
| | | self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight |
| | | self.inter_score_loss_weight = inter_score_loss_weight |
| | | self.forward_steps = 0 |
| | | self.inputs_type = inputs_type |
| | | |
| | | def generate_pse_embedding(self): |
| | | embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float) |
| | |
| | | binary_labels: (Batch, frames, max_spk_num) |
| | | binary_labels_lengths: (Batch,) |
| | | """ |
| | | assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape) |
| | | assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape) |
| | | batch_size = speech.shape[0] |
| | | self.forward_steps = self.forward_steps + 1 |
| | | if self.pse_embedding.device != speech.device: |
| | | self.pse_embedding = self.pse_embedding.to(speech.device) |
| | | self.power_weight = self.power_weight.to(speech.device) |
| | | self.int_token_arr = self.int_token_arr.to(speech.device) |
| | | |
| | | # 1. Network forward |
| | | pred, inter_outputs = self.prediction_forward( |
| | | speech, speech_lengths, |
| | |
| | | # the sequence length of 'pred' might be slightly less than the |
| | | # length of 'spk_labels'. Here we force them to be equal. |
| | | length_diff_tolerance = 2 |
| | | length_diff = pse_labels.shape[1] - pred.shape[1] |
| | | if 0 < length_diff <= length_diff_tolerance: |
| | | pse_labels = pse_labels[:, 0: pred.shape[1]] |
| | | length_diff = abs(pse_labels.shape[1] - pred.shape[1]) |
| | | if length_diff <= length_diff_tolerance: |
| | | min_len = min(pred.shape[1], pse_labels.shape[1]) |
| | | pse_labels = pse_labels[:, :min_len] |
| | | pred = pred[:, :min_len] |
| | | cd_score = cd_score[:, :min_len] |
| | | ci_score = ci_score[:, :min_len] |
| | | |
| | | loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths) |
| | | loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths) |
| | |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | if self.encoder is not None: |
| | | if self.encoder is not None and self.inputs_type == "raw": |
| | | speech, speech_lengths = self.encode(speech, speech_lengths) |
| | | speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1]) |
| | | speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float() |
| | |
| | | config_file: Union[Path, str] = None, |
| | | model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | device: Union[str, torch.device] = "cpu", |
| | | ): |
| | | """Build model from the files. |
| | | |
| | |
| | | model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | model_dict = cls.fileter_model_dict(model_dict, model.state_dict()) |
| | | model.load_state_dict(model_dict) |
| | | if model_name_pth is not None and not os.path.exists(model_name_pth): |
| | | torch.save(model_dict, model_name_pth) |
| | |
| | | return model, args |
| | | |
| | | @classmethod |
| | | def fileter_model_dict(cls, src_dict: dict, dest_dict: dict): |
| | | from collections import OrderedDict |
| | | new_dict = OrderedDict() |
| | | for key, value in src_dict.items(): |
| | | if key in dest_dict: |
| | | new_dict[key] = value |
| | | else: |
| | | logging.info("{} is no longer needed in this model.".format(key)) |
| | | for key, value in dest_dict.items(): |
| | | if key not in new_dict: |
| | | logging.warning("{} is missed in checkpoint.".format(key)) |
| | | return new_dict |
| | | |
| | | @classmethod |
| | | def convert_tf2torch( |
| | | cls, |
| | | model, |