| | |
| | | yield |
| | | |
| | | |
def pad_attractor(att, max_n_speakers):
    """Zero-pad an attractor matrix along its first (speaker) axis.

    Args:
        att: 2-D tensor of shape ``(C, D)``, one row per attractor.
        max_n_speakers: Number of rows the result should have.

    Returns:
        ``att`` itself when ``C >= max_n_speakers`` (no truncation is done here);
        otherwise a new ``(max_n_speakers, D)`` tensor whose first ``C`` rows are
        ``att`` and whose remaining rows are zeros.
    """
    n_rows, dim = att.shape
    if n_rows >= max_n_speakers:
        return att
    # new_zeros allocates directly on att's device with att's dtype, instead of
    # building a float32 CPU tensor and copying it across afterwards.
    filler = att.new_zeros(max_n_speakers - n_rows, dim)
    return torch.cat([att, filler], dim=0)
| | | |
| | | |
| | | class DiarEENDOLAModel(AbsESPnetModel): |
| | | """CTC-attention hybrid Encoder-Decoder model""" |
| | | |
| | |
| | | # PostNet |
| | | self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True) |
| | | self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1) |
| | | |
def forward_encoder(self, xs, ilens):
    """Run ``self.encoder`` on a batch of variable-length feature sequences.

    Args:
        xs: Sequence of per-utterance feature tensors, each ``(T_i, F)``.
        ilens: Number of valid frames per utterance (1-D tensor or list of ints).

    Returns:
        List of per-utterance embeddings; the i-th is trimmed to ``ilens[i]`` frames.
    """
    xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
    batch_size, max_len = xs.shape[0], xs.shape[1]
    lengths = torch.as_tensor(ilens, device=xs.device)
    # Build the (B, 1, T) validity mask in one vectorised op on the target device.
    # It is sized to the padded length so it always lines up with ``xs``, even
    # if some inputs are longer than their declared ``ilens``.
    frame_idx = torch.arange(max_len, device=xs.device)
    valid = frame_idx.unsqueeze(0) < lengths.unsqueeze(1)
    xs_mask = valid.to(torch.get_default_dtype()).unsqueeze(-2)
    emb = self.encoder(xs, xs_mask)
    # Reshape back to (B, T, D); NOTE(review): the encoder may return a
    # flattened (B*T, D) tensor, which this view also accepts.
    emb = emb.view(batch_size, max_len, -1)
    return [e[:n] for e, n in zip(emb, ilens)]
| | | |
def forward_post_net(self, logits, ilens):
    """Run the PostNet RNN and output projection over variable-length inputs.

    Args:
        logits: Sequence of per-utterance tensors, each ``(T_i, max_n_speaker)``.
        ilens: Valid length of each utterance (1-D tensor, possibly on GPU, or list).

    Returns:
        List of ``self.output_layer`` outputs, the i-th trimmed to ``ilens[i]`` frames.
    """
    # pack_padded_sequence requires lengths as a CPU int64 tensor (or list);
    # moving them here lets callers pass CUDA length tensors directly.
    lengths = torch.as_tensor(ilens).to(device="cpu", dtype=torch.int64)
    maxlen = int(lengths.max())
    padded = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
    packed = nn.utils.rnn.pack_padded_sequence(
        padded, lengths, batch_first=True, enforce_sorted=False)
    # Only the sequence output is needed; the hidden state is discarded.
    packed_out, _ = self.PostNet(packed)
    outputs, _ = nn.utils.rnn.pad_packed_sequence(
        packed_out, batch_first=True, padding_value=-1, total_length=maxlen)
    return [self.output_layer(out[:n]) for out, n in zip(outputs, lengths.tolist())]
| | | |
| | | def forward( |
| | | self, |
| | |
def estimate_sequential(self,
                        speech: torch.Tensor,
                        speech_lengths: torch.Tensor,
                        n_speakers: int = None,
                        shuffle: bool = True,
                        threshold: float = 0.5,
                        **kwargs):
    """Estimate per-frame speaker activities for a batch of utterances.

    Pipeline: encode frames, estimate attractors with ``self.eda``, choose how
    many attractors to keep, score frames against them, run the PostNet, and
    decode with ``self.recover_y_from_powerlabel``.

    Args:
        speech: Batch of feature sequences; item i is cut to ``speech_lengths[i]``.
        speech_lengths: Valid length of each utterance.
        n_speakers: If truthy and non-negative, keep exactly the first
            ``n_speakers`` attractors for every utterance.
        shuffle: If True, the EDA module sees a random frame permutation of each
            utterance (drawn from NumPy's global RNG).
        threshold: Used when ``n_speakers`` is not usable. The speaker count is
            the index of the first attractor existence probability below it. If
            none is below it, all attractors are kept.
        **kwargs: Ignored.

    Returns:
        Tuple ``(ys, emb, attractors, raw_n_speakers)``:
        - ys: decoded outputs of ``recover_y_from_powerlabel``, one per utterance.
        - emb: encoder embeddings, one per utterance.
        - attractors: the selected attractors, zero-padded or truncated to
          ``self.max_n_speaker`` rows.
        - raw_n_speakers: number of selected attractors before padding/truncation.

    Raises:
        NotImplementedError: if neither a usable ``n_speakers`` nor a ``threshold``
            is given.
    """
    use_fixed_count = bool(n_speakers and n_speakers >= 0)
    if not use_fixed_count and threshold is None:
        # Previously the exception was only constructed, never raised, which
        # silently dropped utterances and misaligned every zip() below.
        raise NotImplementedError('n_speakers or threshold has to be given.')

    speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
    emb = self.forward_encoder(speech, speech_lengths)

    if shuffle:
        shuffled = []
        for e in emb:
            order = np.arange(e.shape[0])
            np.random.shuffle(order)
            # Index on the embedding's own device (the old code used an undefined `xs`).
            shuffled.append(e[torch.from_numpy(order).to(device=e.device, dtype=torch.long)])
        attractors, probs = self.eda.estimate(shuffled)
    else:
        attractors, probs = self.eda.estimate(emb)

    attractors_active = []
    for p, att in zip(probs, attractors):
        if use_fixed_count:
            attractors_active.append(att[:n_speakers])
        else:
            # torch.nonzero returns an index tensor of shape (K, p.dim()). Row 0,
            # column 0 is the first position where p < threshold. If K == 0 (no
            # probability below the threshold), keep every attractor instead of
            # raising IndexError.
            below = torch.nonzero(p < threshold)
            n_spk = int(below[0, 0]) if below.shape[0] > 0 else None
            attractors_active.append(att[:n_spk])

    raw_n_speakers = [att.shape[0] for att in attractors_active]
    attractors = [
        pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker
        else att[:self.max_n_speaker]
        for att in attractors_active
    ]
    ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
    logits = self.forward_post_net(ys, speech_lengths)
    ys = [self.recover_y_from_powerlabel(logit, n_spk)
          for logit, n_spk in zip(logits, raw_n_speakers)]

    return ys, emb, attractors, raw_n_speakers
| | | |
| | | def recover_y_from_powerlabel(self, logit, n_speaker): |
| | | pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) # (T, ) |
| | | pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) |
| | | oov_index = torch.where(pred == self.mapping_dict['oov'])[0] |
| | | for i in oov_index: |
| | | if i > 0: |
| | |
| | | else: |
| | | pred[i] = 0 |
| | | pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred] |
| | | # print(pred) |
| | | decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred] |
| | | decisions = torch.from_numpy( |
| | | np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to( |