From 3fb2ca8378fc21b8f8dc3a451797a54ed42132d2 Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期五, 10 三月 2023 15:28:18 +0800
Subject: [PATCH] Merge pull request #204 from alibaba-damo-academy/dev_wjm
---
funasr/modules/eend_ola/__init__.py | 0
funasr/modules/eend_ola/encoder_decoder_attractor.py | 50 +++++
funasr/modules/eend_ola/encoder.py | 127 ++++++++++++++
funasr/modules/eend_ola/utils/report.py | 159 +++++++++++++++++
funasr/modules/eend_ola/utils/power.py | 95 ++++++++++
funasr/modules/eend_ola/utils/losses.py | 67 +++++++
6 files changed, 498 insertions(+), 0 deletions(-)
diff --git a/funasr/modules/eend_ola/__init__.py b/funasr/modules/eend_ola/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/modules/eend_ola/__init__.py
diff --git a/funasr/modules/eend_ola/encoder.py b/funasr/modules/eend_ola/encoder.py
new file mode 100644
index 0000000..17d11ac
--- /dev/null
+++ b/funasr/modules/eend_ola/encoder.py
@@ -0,0 +1,127 @@
+import math
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class MultiHeadSelfAttention(nn.Module):
+ def __init__(self, n_units, h=8, dropout_rate=0.1):
+ super(MultiHeadSelfAttention, self).__init__()
+ self.linearQ = nn.Linear(n_units, n_units)
+ self.linearK = nn.Linear(n_units, n_units)
+ self.linearV = nn.Linear(n_units, n_units)
+ self.linearO = nn.Linear(n_units, n_units)
+ self.d_k = n_units // h
+ self.h = h
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def __call__(self, x, batch_size, x_mask):
+ q = self.linearQ(x).view(batch_size, -1, self.h, self.d_k)
+ k = self.linearK(x).view(batch_size, -1, self.h, self.d_k)
+ v = self.linearV(x).view(batch_size, -1, self.h, self.d_k)
+ scores = torch.matmul(
+ q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1)) / math.sqrt(self.d_k)
+ if x_mask is not None:
+ x_mask = x_mask.unsqueeze(1)
+ scores = scores.masked_fill(x_mask == 0, -1e9)
+ self.att = F.softmax(scores, dim=3)
+ p_att = self.dropout(self.att)
+ x = torch.matmul(p_att, v.permute(0, 2, 1, 3))
+ x = x.permute(0, 2, 1, 3).contiguous().view(-1, self.h * self.d_k)
+ return self.linearO(x)
+
+
+class PositionwiseFeedForward(nn.Module):
+ def __init__(self, n_units, d_units, dropout_rate):
+ super(PositionwiseFeedForward, self).__init__()
+ self.linear1 = nn.Linear(n_units, d_units)
+ self.linear2 = nn.Linear(d_units, n_units)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def __call__(self, x):
+ return self.linear2(self.dropout(F.relu(self.linear1(x))))
+
+
+class PositionalEncoding(torch.nn.Module):
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+ super(PositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.reverse = reverse
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x):
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model)
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor):
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, idim, n_layers, n_units,
+ e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False):
+ super(TransformerEncoder, self).__init__()
+ self.lnorm_in = nn.LayerNorm(n_units)
+ self.n_layers = n_layers
+ self.dropout = nn.Dropout(dropout_rate)
+ for i in range(n_layers):
+ setattr(self, '{}{:d}'.format("lnorm1_", i),
+ nn.LayerNorm(n_units))
+ setattr(self, '{}{:d}'.format("self_att_", i),
+ MultiHeadSelfAttention(n_units, h))
+ setattr(self, '{}{:d}'.format("lnorm2_", i),
+ nn.LayerNorm(n_units))
+ setattr(self, '{}{:d}'.format("ff_", i),
+ PositionwiseFeedForward(n_units, e_units, dropout_rate))
+ self.lnorm_out = nn.LayerNorm(n_units)
+ if use_pos_emb:
+ self.pos_enc = torch.nn.Sequential(
+ torch.nn.Linear(idim, n_units),
+ torch.nn.LayerNorm(n_units),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ PositionalEncoding(n_units, dropout_rate),
+ )
+ else:
+ self.linear_in = nn.Linear(idim, n_units)
+ self.pos_enc = None
+
+ def __call__(self, x, x_mask=None):
+ BT_size = x.shape[0] * x.shape[1]
+ if self.pos_enc is not None:
+ e = self.pos_enc(x)
+ e = e.view(BT_size, -1)
+ else:
+ e = self.linear_in(x.reshape(BT_size, -1))
+ for i in range(self.n_layers):
+ e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
+ s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
+ e = e + self.dropout(s)
+ e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
+ s = getattr(self, '{}{:d}'.format("ff_", i))(e)
+ e = e + self.dropout(s)
+ return self.lnorm_out(e)
diff --git a/funasr/modules/eend_ola/encoder_decoder_attractor.py b/funasr/modules/eend_ola/encoder_decoder_attractor.py
new file mode 100644
index 0000000..db01b00
--- /dev/null
+++ b/funasr/modules/eend_ola/encoder_decoder_attractor.py
@@ -0,0 +1,50 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class EncoderDecoderAttractor(nn.Module):
+
+ def __init__(self, n_units, encoder_dropout=0.1, decoder_dropout=0.1):
+ super(EncoderDecoderAttractor, self).__init__()
+ self.enc0_dropout = nn.Dropout(encoder_dropout)
+ self.encoder = nn.LSTM(n_units, n_units, 1, batch_first=True, dropout=encoder_dropout)
+ self.dec0_dropout = nn.Dropout(decoder_dropout)
+ self.decoder = nn.LSTM(n_units, n_units, 1, batch_first=True, dropout=decoder_dropout)
+ self.counter = nn.Linear(n_units, 1)
+ self.n_units = n_units
+
+ def forward_core(self, xs, zeros):
+ ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.float32).to(xs[0].device)
+ xs = [self.enc0_dropout(x) for x in xs]
+ xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
+ xs = nn.utils.rnn.pack_padded_sequence(xs, ilens, batch_first=True, enforce_sorted=False)
+ _, (hx, cx) = self.encoder(xs)
+ zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.float32).to(zeros[0].device)
+ max_zlen = torch.max(zlens).to(torch.int).item()
+ zeros = [self.enc0_dropout(z) for z in zeros]
+ zeros = nn.utils.rnn.pad_sequence(zeros, batch_first=True, padding_value=-1)
+ zeros = nn.utils.rnn.pack_padded_sequence(zeros, zlens, batch_first=True, enforce_sorted=False)
+ attractors, (_, _) = self.decoder(zeros, (hx, cx))
+ attractors = nn.utils.rnn.pad_packed_sequence(attractors, batch_first=True, padding_value=-1,
+ total_length=max_zlen)[0]
+ attractors = [att[:zlens[i].to(torch.int).item()] for i, att in enumerate(attractors)]
+ return attractors
+
+ def forward(self, xs, n_speakers):
+ zeros = [torch.zeros(n_spk + 1, self.n_units).to(torch.float32).to(xs[0].device) for n_spk in n_speakers]
+ attractors = self.forward_core(xs, zeros)
+ labels = torch.cat([torch.from_numpy(np.array([[1] * n_spk + [0]], np.float32)) for n_spk in n_speakers], dim=1)
+ labels = labels.to(xs[0].device)
+ logit = torch.cat([self.counter(att).view(-1, n_spk + 1) for att, n_spk in zip(attractors, n_speakers)], dim=1)
+ loss = F.binary_cross_entropy(torch.sigmoid(logit), labels)
+
+ attractors = [att[slice(0, att.shape[0] - 1)] for att in attractors]
+ return loss, attractors
+
+ def estimate(self, xs, max_n_speakers=15):
+ zeros = [torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs]
+ attractors = self.forward_core(xs, zeros)
+ probs = [torch.sigmoid(torch.flatten(self.counter(att))) for att in attractors]
+ return attractors, probs
\ No newline at end of file
diff --git a/funasr/modules/eend_ola/utils/losses.py b/funasr/modules/eend_ola/utils/losses.py
new file mode 100644
index 0000000..af0181d
--- /dev/null
+++ b/funasr/modules/eend_ola/utils/losses.py
@@ -0,0 +1,67 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from itertools import permutations
+from torch import nn
+
+
+def standard_loss(ys, ts, label_delay=0):
+ losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
+ loss = torch.sum(torch.stack(losses))
+ n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
+ loss = loss / n_frames
+ return loss
+
+
+def batch_pit_n_speaker_loss(ys, ts, n_speakers_list):
+ max_n_speakers = ts[0].shape[1]
+ olens = [y.shape[0] for y in ys]
+ ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1)
+ ys_mask = [torch.ones(olen).to(ys.device) for olen in olens]
+ ys_mask = torch.nn.utils.rnn.pad_sequence(ys_mask, batch_first=True, padding_value=0).unsqueeze(-1)
+
+ losses = []
+ for shift in range(max_n_speakers):
+ ts_roll = [torch.roll(t, -shift, dims=1) for t in ts]
+ ts_roll = nn.utils.rnn.pad_sequence(ts_roll, batch_first=True, padding_value=-1)
+ loss = F.binary_cross_entropy(torch.sigmoid(ys), ts_roll, reduction='none')
+ if ys_mask is not None:
+ loss = loss * ys_mask
+ loss = torch.sum(loss, dim=1)
+ losses.append(loss)
+ losses = torch.stack(losses, dim=2)
+
+ perms = np.array(list(permutations(range(max_n_speakers)))).astype(np.float32)
+ perms = torch.from_numpy(perms).to(losses.device)
+ y_ind = torch.arange(max_n_speakers, dtype=torch.float32, device=losses.device)
+ t_inds = torch.fmod(perms - y_ind, max_n_speakers).to(torch.long)
+
+ losses_perm = []
+ for t_ind in t_inds:
+ losses_perm.append(
+ torch.mean(losses[:, y_ind.to(torch.long), t_ind], dim=1))
+ losses_perm = torch.stack(losses_perm, dim=1)
+
+ def select_perm_indices(num, max_num):
+ perms = list(permutations(range(max_num)))
+ sub_perms = list(permutations(range(num)))
+ return [
+ [x[:num] for x in perms].index(perm)
+ for perm in sub_perms]
+
+ masks = torch.full_like(losses_perm, device=losses.device, fill_value=float('inf'))
+ for i, t in enumerate(ts):
+ n_speakers = n_speakers_list[i]
+ indices = select_perm_indices(n_speakers, max_n_speakers)
+ masks[i, indices] = 0
+ losses_perm += masks
+
+ min_loss = torch.sum(torch.min(losses_perm, dim=1)[0])
+ n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(losses.device)
+ min_loss = min_loss / n_frames
+
+ min_indices = torch.argmin(losses_perm, dim=1)
+ labels_perm = [t[:, perms[idx].to(torch.long)] for t, idx in zip(ts, min_indices)]
+ labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)]
+
+ return min_loss, labels_perm
diff --git a/funasr/modules/eend_ola/utils/power.py b/funasr/modules/eend_ola/utils/power.py
new file mode 100644
index 0000000..7144e24
--- /dev/null
+++ b/funasr/modules/eend_ola/utils/power.py
@@ -0,0 +1,95 @@
+import numpy as np
+import torch
+import torch.multiprocessing
+import torch.nn.functional as F
+from itertools import combinations
+from itertools import permutations
+
+
+def generate_mapping_dict(max_speaker_num=6, max_olp_speaker_num=3):
+ all_kinds = []
+ all_kinds.append(0)
+ for i in range(max_olp_speaker_num):
+ selected_num = i + 1
+ coms = np.array(list(combinations(np.arange(max_speaker_num), selected_num)))
+ for com in coms:
+ tmp = np.zeros(max_speaker_num)
+ tmp[com] = 1
+ item = int(raw_dec_trans(tmp.reshape(1, -1), max_speaker_num)[0])
+ all_kinds.append(item)
+ all_kinds_order = sorted(all_kinds)
+
+ mapping_dict = {}
+ mapping_dict['dec2label'] = {}
+ mapping_dict['label2dec'] = {}
+ for i in range(len(all_kinds_order)):
+ dec = all_kinds_order[i]
+ mapping_dict['dec2label'][dec] = i
+ mapping_dict['label2dec'][i] = dec
+ oov_id = len(all_kinds_order)
+ mapping_dict['oov'] = oov_id
+ return mapping_dict
+
+
+def raw_dec_trans(x, max_speaker_num):
+ num_list = []
+ for i in range(max_speaker_num):
+ num_list.append(x[:, i])
+ base = 1
+ T = x.shape[0]
+ res = np.zeros((T))
+ for num in num_list:
+ res += num * base
+ base = base * 2
+ return res
+
+
+def mapping_func(num, mapping_dict):
+ if num in mapping_dict['dec2label'].keys():
+ label = mapping_dict['dec2label'][num]
+ else:
+ label = mapping_dict['oov']
+ return label
+
+
+def dec_trans(x, max_speaker_num, mapping_dict):
+ num_list = []
+ for i in range(max_speaker_num):
+ num_list.append(x[:, i])
+ base = 1
+ T = x.shape[0]
+ res = np.zeros((T))
+ for num in num_list:
+ res += num * base
+ base = base * 2
+ res = np.array([mapping_func(i, mapping_dict) for i in res])
+ return res
+
+
+def create_powerlabel(label, mapping_dict, max_speaker_num=6, max_olp_speaker_num=3):
+ T, C = label.shape
+ padding_label = np.zeros((T, max_speaker_num))
+ padding_label[:, :C] = label
+ out_label = dec_trans(padding_label, max_speaker_num, mapping_dict)
+ out_label = torch.from_numpy(out_label)
+ return out_label
+
+
+def generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=3):
+ perms = np.array(list(permutations(range(n_speaker)))).astype(np.float32)
+ perms = torch.from_numpy(perms).to(label.device).to(torch.int64)
+ perm_labels = [label[:, perm] for perm in perms]
+ perm_pse_labels = [create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).
+ to(perm_label.device, non_blocking=True) for perm_label in perm_labels]
+ return perm_labels, perm_pse_labels
+
+
+def generate_min_pse(label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3):
+ perm_labels, perm_pse_labels = generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num,
+ max_olp_speaker_num=max_olp_speaker_num)
+ losses = [F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit)
+ for perm_pse_label in perm_pse_labels]
+ loss = torch.stack(losses)
+ min_index = torch.argmin(loss)
+ selected_perm_label, selected_pse_label = perm_labels[min_index], perm_pse_labels[min_index]
+ return selected_perm_label, selected_pse_label
diff --git a/funasr/modules/eend_ola/utils/report.py b/funasr/modules/eend_ola/utils/report.py
new file mode 100644
index 0000000..bfccedf
--- /dev/null
+++ b/funasr/modules/eend_ola/utils/report.py
@@ -0,0 +1,159 @@
+import copy
+import numpy as np
+import time
+import torch
+from eend.utils.power import create_powerlabel
+from itertools import combinations
+
+metrics = [
+ ('diarization_error', 'speaker_scored', 'DER'),
+ ('speech_miss', 'speech_scored', 'SAD_MR'),
+ ('speech_falarm', 'speech_scored', 'SAD_FR'),
+ ('speaker_miss', 'speaker_scored', 'MI'),
+ ('speaker_falarm', 'speaker_scored', 'FA'),
+ ('speaker_error', 'speaker_scored', 'CF'),
+ ('correct', 'frames', 'accuracy')
+]
+
+
+def recover_prediction(y, n_speaker):
+ if n_speaker <= 1:
+ return y
+ elif n_speaker == 2:
+ com_index = torch.from_numpy(
+ np.array(list(combinations(np.arange(n_speaker), 2)))).to(
+ y.dtype)
+ num_coms = com_index.shape[0]
+ y_single = y[:, :-num_coms]
+ y_olp = y[:, -num_coms:]
+ olp_map_index = torch.where(y_olp > 0.5)
+ olp_map_index = torch.stack(olp_map_index, dim=1)
+ com_map_index = com_index[olp_map_index[:, -1]]
+ speaker_map_index = torch.from_numpy(np.array(com_map_index)).view(-1).to(torch.int64)
+ frame_map_index = olp_map_index[:, 0][:, None].repeat([1, 2]).view(-1).to(
+ torch.int64)
+ y_single[frame_map_index] = 0
+ y_single[frame_map_index, speaker_map_index] = 1
+ return y_single
+ else:
+ olp2_com_index = torch.from_numpy(np.array(list(combinations(np.arange(n_speaker), 2)))).to(y.dtype)
+ olp2_num_coms = olp2_com_index.shape[0]
+ olp3_com_index = torch.from_numpy(np.array(list(combinations(np.arange(n_speaker), 3)))).to(y.dtype)
+ olp3_num_coms = olp3_com_index.shape[0]
+ y_single = y[:, :n_speaker]
+ y_olp2 = y[:, n_speaker:n_speaker + olp2_num_coms]
+ y_olp3 = y[:, -olp3_num_coms:]
+
+ olp3_map_index = torch.where(y_olp3 > 0.5)
+ olp3_map_index = torch.stack(olp3_map_index, dim=1)
+ olp3_com_map_index = olp3_com_index[olp3_map_index[:, -1]]
+ olp3_speaker_map_index = torch.from_numpy(np.array(olp3_com_map_index)).view(-1).to(torch.int64)
+ olp3_frame_map_index = olp3_map_index[:, 0][:, None].repeat([1, 3]).view(-1).to(torch.int64)
+ y_single[olp3_frame_map_index] = 0
+ y_single[olp3_frame_map_index, olp3_speaker_map_index] = 1
+ y_olp2[olp3_frame_map_index] = 0
+
+ olp2_map_index = torch.where(y_olp2 > 0.5)
+ olp2_map_index = torch.stack(olp2_map_index, dim=1)
+ olp2_com_map_index = olp2_com_index[olp2_map_index[:, -1]]
+ olp2_speaker_map_index = torch.from_numpy(np.array(olp2_com_map_index)).view(-1).to(torch.int64)
+ olp2_frame_map_index = olp2_map_index[:, 0][:, None].repeat([1, 2]).view(-1).to(torch.int64)
+ y_single[olp2_frame_map_index] = 0
+ y_single[olp2_frame_map_index, olp2_speaker_map_index] = 1
+ return y_single
+
+
+class PowerReporter():
+ def __init__(self, valid_data_loader, mapping_dict, max_n_speaker):
+ valid_data_loader_cp = copy.deepcopy(valid_data_loader)
+ self.valid_data_loader = valid_data_loader_cp
+ del valid_data_loader
+ self.mapping_dict = mapping_dict
+ self.max_n_speaker = max_n_speaker
+
+ def report(self, model, eidx, device):
+ self.report_val(model, eidx, device)
+
+ def report_val(self, model, eidx, device):
+ model.eval()
+ ud_valid_start = time.time()
+ valid_res, valid_loss, stats_keys, vad_valid_accuracy = self.report_core(model, self.valid_data_loader, device)
+
+ # Epoch Display
+ valid_der = valid_res['diarization_error'] / valid_res['speaker_scored']
+ valid_accuracy = valid_res['correct'].to(torch.float32) / valid_res['frames'] * 100
+ vad_valid_accuracy = vad_valid_accuracy * 100
+ print('Epoch ', eidx + 1, 'Valid Loss ', valid_loss, 'Valid_DER %.5f' % valid_der,
+ 'Valid_Accuracy %.5f%% ' % valid_accuracy, 'VAD_Valid_Accuracy %.5f%% ' % vad_valid_accuracy)
+ ud_valid = (time.time() - ud_valid_start) / 60.
+ print('Valid cost time ... ', ud_valid)
+
+ def inv_mapping_func(self, label, mapping_dict):
+ if not isinstance(label, int):
+ label = int(label)
+ if label in mapping_dict['label2dec'].keys():
+ num = mapping_dict['label2dec'][label]
+ else:
+ num = -1
+ return num
+
+ def report_core(self, model, data_loader, device):
+ res = {}
+ for item in metrics:
+ res[item[0]] = 0.
+ res[item[1]] = 0.
+ with torch.no_grad():
+ loss_s = 0.
+ uidx = 0
+ for xs, ts, orders in data_loader:
+ xs = [x.to(device) for x in xs]
+ ts = [t.to(device) for t in ts]
+ orders = [o.to(device) for o in orders]
+ loss, pit_loss, mpit_loss, att_loss, ys, logits, labels, attractors = model(xs, ts, orders)
+ loss_s += loss.item()
+ uidx += 1
+
+ for logit, t, att in zip(logits, labels, attractors):
+ pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) # (T, )
+ oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
+ for i in oov_index:
+ if i > 0:
+ pred[i] = pred[i - 1]
+ else:
+ pred[i] = 0
+ pred = [self.inv_mapping_func(i, self.mapping_dict) for i in pred]
+ decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
+ decisions = torch.from_numpy(
+ np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(att.device).to(
+ torch.float32)
+ decisions = decisions[:, :att.shape[0]]
+
+ stats = self.calc_diarization_error(decisions, t)
+ res['speaker_scored'] += stats['speaker_scored']
+ res['speech_scored'] += stats['speech_scored']
+ res['frames'] += stats['frames']
+ for item in metrics:
+ res[item[0]] += stats[item[0]]
+ loss_s /= uidx
+ vad_acc = 0
+
+ return res, loss_s, stats.keys(), vad_acc
+
+ def calc_diarization_error(self, decisions, label, label_delay=0):
+ label = label[:len(label) - label_delay, ...]
+ n_ref = torch.sum(label, dim=-1)
+ n_sys = torch.sum(decisions, dim=-1)
+ res = {}
+ res['speech_scored'] = torch.sum(n_ref > 0)
+ res['speech_miss'] = torch.sum((n_ref > 0) & (n_sys == 0))
+ res['speech_falarm'] = torch.sum((n_ref == 0) & (n_sys > 0))
+ res['speaker_scored'] = torch.sum(n_ref)
+ res['speaker_miss'] = torch.sum(torch.max(n_ref - n_sys, torch.zeros_like(n_ref)))
+ res['speaker_falarm'] = torch.sum(torch.max(n_sys - n_ref, torch.zeros_like(n_ref)))
+ n_map = torch.sum(((label == 1) & (decisions == 1)), dim=-1).to(torch.float32)
+ res['speaker_error'] = torch.sum(torch.min(n_ref, n_sys) - n_map)
+ res['correct'] = torch.sum(label == decisions) / label.shape[1]
+ res['diarization_error'] = (
+ res['speaker_miss'] + res['speaker_falarm'] + res['speaker_error'])
+ res['frames'] = len(label)
+ return res
--
Gitblit v1.9.1