| | |
| | | from funasr.build_utils.build_diar_model import class_choices_list |
| | | for class_choices in class_choices_list: |
| | | class_choices.add_arguments(task_parser) |
| | | task_parser.add_argument( |
| | | "--input_size", |
| | | type=int_or_none, |
| | | default=None, |
| | | help="The number of input dimension of the feature", |
| | | ) |
| | | |
| | | elif args.task_name == "sv": |
| | | from funasr.build_utils.build_sv_model import class_choices_list |
| | |
| | | |
| | | def build_dataloader(args): |
| | | if args.dataset_type == "small": |
| | | train_iter_factory = SequenceIterFactory(args, mode="train") |
| | | valid_iter_factory = SequenceIterFactory(args, mode="valid") |
| | | if args.task_name == "diar" and args.model == "eend_ola": |
| | | from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader |
| | | train_iter_factory = EENDOLADataLoader( |
| | | data_file=args.train_data_path_and_name_and_type[0][0], |
| | | batch_size=args.dataset_conf["batch_conf"]["batch_size"], |
| | | num_workers=args.dataset_conf["num_workers"], |
| | | shuffle=True) |
| | | valid_iter_factory = EENDOLADataLoader( |
| | | data_file=args.valid_data_path_and_name_and_type[0][0], |
| | | batch_size=args.dataset_conf["batch_conf"]["batch_size"], |
| | | num_workers=0, |
| | | shuffle=False) |
| | | else: |
| | | train_iter_factory = SequenceIterFactory(args, mode="train") |
| | | valid_iter_factory = SequenceIterFactory(args, mode="valid") |
| | | elif args.dataset_type == "large": |
| | | train_iter_factory = LargeDataLoader(args, mode="train") |
| | | valid_iter_factory = LargeDataLoader(args, mode="valid") |
| | |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | | else: |
| | | frontend = frontend_class(**args.frontend_conf) |
| | | input_size = frontend.output_size() |
| | | else: |
| | | args.frontend = None |
| | | args.frontend_conf = {} |
| | | frontend = None |
| | | input_size = args.input_size |
| | | |
| | | # encoder |
| | | encoder_class = encoder_choices.get_class(args.encoder) |
| | | encoder = encoder_class(input_size=input_size, **args.encoder_conf) |
| | | encoder = encoder_class(**args.encoder_conf) |
| | | |
| | | if args.model == "sond": |
| | | # data augmentation for spectrogram |
| | |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | elif args.model_name == "eend_ola": |
| | | elif args.model == "eend_ola": |
| | | # encoder-decoder attractor |
| | | encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor) |
| | | encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf) |
| | |
| | | data_path_and_name_and_type, |
| | | preprocess=preprocess_fn, |
| | | dest_sample_rate=dest_sample_rate, |
| | | speed_perturb=args.speed_perturb if mode=="train" else None, |
| | | speed_perturb=args.speed_perturb if mode == "train" else None, |
| | | ) |
| | | |
| | | # sampler |
| | |
| | | args.max_update = len(bs_list) * args.max_epoch |
| | | logging.info("Max update: {}".format(args.max_update)) |
| | | |
| | | if args.distributed and mode=="train": |
| | | if args.distributed and mode == "train": |
| | | world_size = torch.distributed.get_world_size() |
| | | rank = torch.distributed.get_rank() |
| | | for batch in batches: |
| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | from typing import Dict |
| | | from typing import Tuple |
| | | from typing import Dict, List, Tuple, Optional |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.models.frontend.wav_frontend import WavFrontendMel23 |
| | | from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder |
| | | from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor |
| | | from funasr.modules.eend_ola.utils.losses import fast_batch_pit_n_speaker_loss, standard_loss, cal_power_loss |
| | | from funasr.modules.eend_ola.utils.power import create_powerlabel |
| | | from funasr.modules.eend_ola.utils.power import generate_mapping_dict |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.models.base_model import FunASRModel |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | pass |
| | |
| | | return att |
| | | |
| | | |
def pad_labels(ts, out_size):
    """Zero-pad label tensors on the right so dimension 1 reaches ``out_size``.

    The list is modified in place: entries that are too narrow are replaced
    by their padded version, entries that are already wide enough are left
    untouched.

    Args:
        ts: Mutable list of tensors. Each tensor is indexed as 2-D here
            (presumably frames x speakers -- confirm with the caller).
        out_size: Minimum size required along dimension 1.

    Returns:
        The same list object that was passed in.
    """
    for index, label in enumerate(ts):
        missing = out_size - label.shape[1]
        if missing <= 0:
            continue
        # Pad spec is (left, right) for the last dim, then (top, bottom) for
        # the one before it: only the right edge of the last dim grows.
        ts[index] = F.pad(label, (0, missing, 0, 0), mode="constant", value=0.0)
    return ts
| | | |
| | | |
def pad_results(ys, out_size):
    """Return a new list where each tensor has at least ``out_size`` columns.

    Tensors narrower than ``out_size`` along dimension 1 get float32 zero
    columns appended on the right; wider or equal tensors are passed through
    as the same object. The input list itself is not modified.

    Args:
        ys: Iterable of tensors indexed as 2-D here.
        out_size: Minimum size required along dimension 1.

    Returns:
        A list with one (possibly padded) tensor per input tensor.
    """
    padded = []
    for y in ys:
        shortfall = out_size - y.shape[1]
        if shortfall > 0:
            filler = torch.zeros(
                y.shape[0], shortfall, dtype=torch.float32, device=y.device
            )
            y = torch.cat([y, filler], dim=1)
        padded.append(y)
    return padded
| | | |
| | | |
| | | class DiarEENDOLAModel(FunASRModel): |
| | | """EEND-OLA diarization model""" |
| | | |
| | | def __init__( |
| | | self, |
| | | frontend: WavFrontendMel23, |
| | | frontend: Optional[WavFrontendMel23], |
| | | encoder: EENDOLATransformerEncoder, |
| | | encoder_decoder_attractor: EncoderDecoderAttractor, |
| | | n_units: int = 256, |
| | |
| | | mapping_dict=None, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | | super().__init__() |
| | | self.frontend = frontend |
| | | self.enc = encoder |
| | | self.eda = encoder_decoder_attractor |
| | | self.encoder_decoder_attractor = encoder_decoder_attractor |
| | | self.attractor_loss_weight = attractor_loss_weight |
| | | self.max_n_speaker = max_n_speaker |
| | | if mapping_dict is None: |
| | |
| | | def forward_post_net(self, logits, ilens): |
| | | maxlen = torch.max(ilens).to(torch.int).item() |
| | | logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1) |
| | | logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False) |
| | | logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, |
| | | enforce_sorted=False) |
| | | outputs, (_, _) = self.postnet(logits) |
| | | outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0] |
| | | outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)] |
| | |
| | | |
| | | def forward( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | speech: List[torch.Tensor], |
| | | speech_lengths: torch.Tensor, # num_frames of each sample |
| | | speaker_labels: List[torch.Tensor], |
| | | speaker_labels_lengths: torch.Tensor, # num_speakers of each sample |
| | | orders: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Frontend + Encoder + Decoder + Calc loss |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | """ |
| | | assert text_lengths.dim() == 1, text_lengths.shape |
| | | |
| | | # Check that batch_size is unified |
| | | assert ( |
| | | speech.shape[0] |
| | | == speech_lengths.shape[0] |
| | | == text.shape[0] |
| | | == text_lengths.shape[0] |
| | | ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) |
| | | batch_size = speech.shape[0] |
| | | len(speech) |
| | | == len(speech_lengths) |
| | | == len(speaker_labels) |
| | | == len(speaker_labels_lengths) |
| | | ), (len(speech), len(speech_lengths), len(speaker_labels), len(speaker_labels_lengths)) |
| | | batch_size = len(speech) |
| | | |
| | | # for data-parallel |
| | | text = text[:, : text_lengths.max()] |
| | | # Encoder |
| | | speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] |
| | | encoder_out = self.forward_encoder(speech, speech_lengths) |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.enc(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | # Encoder-decoder attractor |
| | | attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)], |
| | | speaker_labels_lengths) |
| | | speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)] |
| | | |
| | | loss_att, acc_att, cer_att, wer_att = None, None, None, None |
| | | loss_ctc, cer_ctc = None, None |
| | | # pit loss |
| | | pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels) |
| | | pit_loss = standard_loss(speaker_logits, pit_speaker_labels) |
| | | |
| | | # pse loss |
| | | with torch.no_grad(): |
| | | power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker). |
| | | to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels] |
| | | pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors] |
| | | pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)] |
| | | pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths) |
| | | pse_loss = cal_power_loss(pse_speaker_logits, power_ts) |
| | | |
| | | loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss |
| | | |
| | | stats = dict() |
| | | |
| | | # 1. CTC branch |
| | | if self.ctc_weight != 0.0: |
| | | loss_ctc, cer_ctc = self._calc_ctc_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | # Collect CTC branch stats |
| | | stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None |
| | | stats["cer_ctc"] = cer_ctc |
| | | |
| | | # Intermediate CTC (optional) |
| | | loss_interctc = 0.0 |
| | | if self.interctc_weight != 0.0 and intermediate_outs is not None: |
| | | for layer_idx, intermediate_out in intermediate_outs: |
| | | # we assume intermediate_out has the same length & padding |
| | | # as those of encoder_out |
| | | loss_ic, cer_ic = self._calc_ctc_loss( |
| | | intermediate_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | loss_interctc = loss_interctc + loss_ic |
| | | |
| | | # Collect Intermedaite CTC stats |
| | | stats["loss_interctc_layer{}".format(layer_idx)] = ( |
| | | loss_ic.detach() if loss_ic is not None else None |
| | | ) |
| | | stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic |
| | | |
| | | loss_interctc = loss_interctc / len(intermediate_outs) |
| | | |
| | | # calculate whole encoder loss |
| | | loss_ctc = ( |
| | | 1 - self.interctc_weight |
| | | ) * loss_ctc + self.interctc_weight * loss_interctc |
| | | |
| | | # 2b. Attention decoder branch |
| | | if self.ctc_weight != 1.0: |
| | | loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | # 3. CTC-Att loss definition |
| | | if self.ctc_weight == 0.0: |
| | | loss = loss_att |
| | | elif self.ctc_weight == 1.0: |
| | | loss = loss_ctc |
| | | else: |
| | | loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att |
| | | |
| | | # Collect Attn branch stats |
| | | stats["loss_att"] = loss_att.detach() if loss_att is not None else None |
| | | stats["acc"] = acc_att |
| | | stats["cer"] = cer_att |
| | | stats["wer"] = wer_att |
| | | stats["pse_loss"] = pse_loss.detach() |
| | | stats["pit_loss"] = pit_loss.detach() |
| | | stats["attractor_loss"] = attractor_loss.detach() |
| | | stats["batch_size"] = batch_size |
| | | |
| | | # Collect total loss stats |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | |
| | | orders = [np.arange(e.shape[0]) for e in emb] |
| | | for order in orders: |
| | | np.random.shuffle(order) |
| | | attractors, probs = self.eda.estimate( |
| | | attractors, probs = self.encoder_decoder_attractor.estimate( |
| | | [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)]) |
| | | else: |
| | | attractors, probs = self.eda.estimate(emb) |
| | | attractors, probs = self.encoder_decoder_attractor.estimate(emb) |
| | | attractors_active = [] |
| | | for p, att, e in zip(probs, attractors, emb): |
| | | if n_speakers and n_speakers >= 0: |
| New file |
| | |
| | | import logging |
| | | |
| | | import kaldiio |
| | | import numpy as np |
| | | import torch |
| | | from torch.utils.data import DataLoader |
| | | from torch.utils.data import Dataset |
| | | |
| | | |
def custom_collate(batch):
    """Collate dataset items into per-field lists of tensors.

    No padding or stacking is performed: every field stays a Python list
    with one tensor per sample.

    Args:
        batch: Sequence of ``(key, speech, speaker_label, order)`` tuples
            whose last three members are numpy arrays.

    Returns:
        A ``(keys, batch)`` pair where ``keys`` is a tuple of the sample keys
        and ``batch`` is a dict with ``speech`` and ``speaker_labels`` as
        float32 tensors and ``orders`` as int64 tensors.
    """
    keys, raw_speech, raw_labels, raw_orders = zip(*batch)

    def _as_tensor(array, dtype):
        # Copy first so the tensor never aliases the source numpy buffer.
        return torch.from_numpy(np.copy(array)).to(dtype)

    collated = {
        "speech": [_as_tensor(item, torch.float32) for item in raw_speech],
        "speaker_labels": [_as_tensor(item, torch.float32) for item in raw_labels],
        "orders": [_as_tensor(item, torch.int64) for item in raw_orders],
    }
    return keys, collated
| | | |
| | | |
class EENDOLADataset(Dataset):
    """Map-style dataset driven by a whitespace-separated manifest file.

    Each non-blank manifest line is expected to hold three fields:
    ``key speech_path speaker_label_path``. Both paths are read with
    ``kaldiio.load_mat`` when an item is requested.
    """

    def __init__(
            self,
            data_file,
    ):
        """Read the manifest and keep its fields in ``self.samples``.

        Args:
            data_file: Path to the manifest text file.
        """
        self.data_file = data_file
        with open(data_file) as f:
            # Blank lines (e.g. a trailing newline block) would otherwise
            # become empty entries that inflate __len__ and break the
            # three-way unpack in __getitem__, so they are dropped here.
            self.samples = [fields for fields in (line.split() for line in f) if fields]
        logging.info("total samples: {}".format(len(self.samples)))

    def __len__(self):
        """Return the number of manifest entries."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into ``self.samples``.

        Returns:
            ``(key, speech, speaker_label, order)`` where ``speaker_label``
            is reshaped to have ``speech.shape[0]`` rows and ``order`` is a
            random permutation of ``range(speech.shape[0])``.
        """
        key, speech_path, speaker_label_path = self.samples[idx]
        speech = kaldiio.load_mat(speech_path)
        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)

        # Fresh shuffle on every access, drawn from numpy's global RNG.
        order = np.arange(speech.shape[0])
        np.random.shuffle(order)

        return key, speech, speaker_label, order
| | | |
| | | |
class EENDOLADataLoader():
    """Thin wrapper that owns a ``DataLoader`` over an ``EENDOLADataset``."""

    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        """Build the dataset from ``data_file`` and wrap it in a DataLoader.

        Args:
            data_file: Manifest path forwarded to ``EENDOLADataset``.
            batch_size: Number of samples per batch.
            shuffle: Forwarded to ``DataLoader``.
            num_workers: Forwarded to ``DataLoader``.
        """
        loader_options = dict(
            batch_size=batch_size,
            collate_fn=custom_collate,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        self.data_loader = DataLoader(EENDOLADataset(data_file), **loader_options)

    def build_iter(self, epoch):
        """Return the wrapped DataLoader; ``epoch`` is accepted but unused."""
        return self.data_loader
| | |
| | | dropout_rate: float = 0.1, |
| | | use_pos_emb: bool = False): |
| | | super(EENDOLATransformerEncoder, self).__init__() |
| | | self.linear_in = nn.Linear(idim, n_units) |
| | | self.lnorm_in = nn.LayerNorm(n_units) |
| | | self.n_layers = n_layers |
| | | self.dropout = nn.Dropout(dropout_rate) |
| | |
| | | setattr(self, '{}{:d}'.format("ff_", i), |
| | | PositionwiseFeedForward(n_units, e_units, dropout_rate)) |
| | | self.lnorm_out = nn.LayerNorm(n_units) |
| | | if use_pos_emb: |
| | | self.pos_enc = torch.nn.Sequential( |
| | | torch.nn.Linear(idim, n_units), |
| | | torch.nn.LayerNorm(n_units), |
| | | torch.nn.Dropout(dropout_rate), |
| | | torch.nn.ReLU(), |
| | | PositionalEncoding(n_units, dropout_rate), |
| | | ) |
| | | else: |
| | | self.linear_in = nn.Linear(idim, n_units) |
| | | self.pos_enc = None |
| | | |
| | | def __call__(self, x, x_mask=None): |
| | | BT_size = x.shape[0] * x.shape[1] |
| | | if self.pos_enc is not None: |
| | | e = self.pos_enc(x) |
| | | e = e.view(BT_size, -1) |
| | | else: |
| | | e = self.linear_in(x.reshape(BT_size, -1)) |
| | | e = self.linear_in(x.reshape(BT_size, -1)) |
| | | for i in range(self.n_layers): |
| | | e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e) |
| | | s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask) |
| | |
| | | e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e) |
| | | s = getattr(self, '{}{:d}'.format("ff_", i))(e) |
| | | e = e + self.dropout(s) |
| | | return self.lnorm_out(e) |
| | | return self.lnorm_out(e) |
| | |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from itertools import permutations |
| | | from torch import nn |
| | | from scipy.optimize import linear_sum_assignment |
| | | |
| | | |
| | | def standard_loss(ys, ts, label_delay=0): |
| | | def standard_loss(ys, ts): |
| | | losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)] |
| | | loss = torch.sum(torch.stack(losses)) |
| | | n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device) |
| | |
| | | return loss |
| | | |
| | | |
| | | def batch_pit_n_speaker_loss(ys, ts, n_speakers_list): |
| | | max_n_speakers = ts[0].shape[1] |
| | | olens = [y.shape[0] for y in ys] |
| | | ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1) |
| | | ys_mask = [torch.ones(olen).to(ys.device) for olen in olens] |
| | | ys_mask = torch.nn.utils.rnn.pad_sequence(ys_mask, batch_first=True, padding_value=0).unsqueeze(-1) |
| | | def fast_batch_pit_n_speaker_loss(ys, ts): |
| | | with torch.no_grad(): |
| | | bs = len(ys) |
| | | indices = [] |
| | | for b in range(bs): |
| | | y = ys[b].transpose(0, 1) |
| | | t = ts[b].transpose(0, 1) |
| | | C, _ = t.shape |
| | | y = y[:, None, :].repeat(1, C, 1) |
| | | t = t[None, :, :].repeat(C, 1, 1) |
| | | bce_loss = F.binary_cross_entropy(torch.sigmoid(y), t, reduction="none").mean(-1) |
| | | C = bce_loss.cpu() |
| | | indices.append(linear_sum_assignment(C)) |
| | | labels_perm = [t[:, idx[1]] for t, idx in zip(ts, indices)] |
| | | |
| | | losses = [] |
| | | for shift in range(max_n_speakers): |
| | | ts_roll = [torch.roll(t, -shift, dims=1) for t in ts] |
| | | ts_roll = nn.utils.rnn.pad_sequence(ts_roll, batch_first=True, padding_value=-1) |
| | | loss = F.binary_cross_entropy(torch.sigmoid(ys), ts_roll, reduction='none') |
| | | if ys_mask is not None: |
| | | loss = loss * ys_mask |
| | | loss = torch.sum(loss, dim=1) |
| | | losses.append(loss) |
| | | losses = torch.stack(losses, dim=2) |
| | | return labels_perm |
| | | |
| | | perms = np.array(list(permutations(range(max_n_speakers)))).astype(np.float32) |
| | | perms = torch.from_numpy(perms).to(losses.device) |
| | | y_ind = torch.arange(max_n_speakers, dtype=torch.float32, device=losses.device) |
| | | t_inds = torch.fmod(perms - y_ind, max_n_speakers).to(torch.long) |
| | | |
| | | losses_perm = [] |
| | | for t_ind in t_inds: |
| | | losses_perm.append( |
| | | torch.mean(losses[:, y_ind.to(torch.long), t_ind], dim=1)) |
| | | losses_perm = torch.stack(losses_perm, dim=1) |
| | | |
| | | def select_perm_indices(num, max_num): |
| | | perms = list(permutations(range(max_num))) |
| | | sub_perms = list(permutations(range(num))) |
| | | return [ |
| | | [x[:num] for x in perms].index(perm) |
| | | for perm in sub_perms] |
| | | |
| | | masks = torch.full_like(losses_perm, device=losses.device, fill_value=float('inf')) |
| | | for i, t in enumerate(ts): |
| | | n_speakers = n_speakers_list[i] |
| | | indices = select_perm_indices(n_speakers, max_n_speakers) |
| | | masks[i, indices] = 0 |
| | | losses_perm += masks |
| | | |
| | | min_loss = torch.sum(torch.min(losses_perm, dim=1)[0]) |
| | | n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(losses.device) |
| | | min_loss = min_loss / n_frames |
| | | |
| | | min_indices = torch.argmin(losses_perm, dim=1) |
| | | labels_perm = [t[:, perms[idx].to(torch.long)] for t, idx in zip(ts, min_indices)] |
| | | labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)] |
| | | |
| | | return min_loss, labels_perm |
def cal_power_loss(logits, power_ts):
    """Compute the frame-weighted mean cross-entropy over a batch of sequences.

    Each sequence's mean cross-entropy is scaled by its number of logit rows,
    the scaled values are summed, and the sum is divided by the total number
    of target frames.

    Args:
        logits: List of per-sequence logit tensors, one row per frame.
        power_ts: List of per-sequence class-index targets; cast to
            ``torch.long`` before the loss is evaluated.

    Returns:
        A scalar tensor holding the normalised loss.
    """
    weighted = []
    for logit, power_t in zip(logits, power_ts):
        mean_ce = F.cross_entropy(input=logit, target=power_t.to(torch.long))
        weighted.append(mean_ce * len(logit))
    total = torch.sum(torch.stack(weighted))

    frame_count = sum(power_t.shape[0] for power_t in power_ts)
    n_frames = torch.tensor(
        frame_count, dtype=torch.float32, device=power_ts[0].device
    )
    return total / n_frames