嘉渊
2023-07-06 ff0310bfb4ed69f00cbeab89a58f958ae5091d70
update eend_ola
7 files changed
1 file added
354 lines changed
funasr/build_utils/build_args.py  6
funasr/build_utils/build_dataloader.py  17
funasr/build_utils/build_diar_model.py  6
funasr/datasets/small_datasets/sequence_iter_factory.py  4
funasr/models/e2e_diar_eend_ola.py  167
funasr/modules/eend_ola/eend_ola_dataloader.py  57
funasr/modules/eend_ola/encoder.py  20
funasr/modules/eend_ola/utils/losses.py  77
funasr/build_utils/build_args.py
@@ -86,6 +86,12 @@
         from funasr.build_utils.build_diar_model import class_choices_list
         for class_choices in class_choices_list:
             class_choices.add_arguments(task_parser)
+        task_parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
     elif args.task_name == "sv":
         from funasr.build_utils.build_sv_model import class_choices_list
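Note: `int_or_none` lets `--input_size` stay unset when a frontend supplies the feature dimension. A minimal sketch of such an argparse type, assuming behavior like espnet's helper of the same name (the real implementation lives elsewhere in funasr and may differ):

    # Hedged sketch of an int_or_none argparse type; not the committed implementation.
    from typing import Optional

    def int_or_none(value: str) -> Optional[int]:
        # argparse hands over the raw string; map "none"-like spellings to None
        if value.strip().lower() in ("none", "null", "nil"):
            return None
        return int(value)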
funasr/build_utils/build_dataloader.py
@@ -4,8 +4,21 @@
 def build_dataloader(args):
     if args.dataset_type == "small":
-        train_iter_factory = SequenceIterFactory(args, mode="train")
-        valid_iter_factory = SequenceIterFactory(args, mode="valid")
+        if args.task_name == "diar" and args.model == "eend_ola":
+            from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
+            train_iter_factory = EENDOLADataLoader(
+                data_file=args.train_data_path_and_name_and_type[0][0],
+                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+                num_workers=args.dataset_conf["num_workers"],
+                shuffle=True)
+            valid_iter_factory = EENDOLADataLoader(
+                data_file=args.valid_data_path_and_name_and_type[0][0],
+                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+                num_workers=0,
+                shuffle=False)
+        else:
+            train_iter_factory = SequenceIterFactory(args, mode="train")
+            valid_iter_factory = SequenceIterFactory(args, mode="valid")
     elif args.dataset_type == "large":
         train_iter_factory = LargeDataLoader(args, mode="train")
         valid_iter_factory = LargeDataLoader(args, mode="valid")
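Note: with this branch, EEND-OLA bypasses SequenceIterFactory entirely and reads its list file directly. A hedged sketch of the arguments the branch consumes (field names mirror the diff; paths and values are assumptions):

    # Illustrative only: a minimal namespace exercising the new eend_ola branch.
    from types import SimpleNamespace

    args = SimpleNamespace(
        dataset_type="small",
        task_name="diar",
        model="eend_ola",
        train_data_path_and_name_and_type=[("dump/train/data.scp", "speech", "kaldi_ark")],
        valid_data_path_and_name_and_type=[("dump/valid/data.scp", "speech", "kaldi_ark")],
        dataset_conf={"batch_conf": {"batch_size": 16}, "num_workers": 4},
    )
    train_iter_factory, valid_iter_factory = build_dataloader(args)  # assuming both are returned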
funasr/build_utils/build_diar_model.py
@@ -198,16 +198,14 @@
             frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
         else:
             frontend = frontend_class(**args.frontend_conf)
-        input_size = frontend.output_size()
     else:
         args.frontend = None
         args.frontend_conf = {}
         frontend = None
-        input_size = args.input_size

     # encoder
     encoder_class = encoder_choices.get_class(args.encoder)
-    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+    encoder = encoder_class(**args.encoder_conf)

     if args.model == "sond":
         # data augmentation for spectrogram
@@ -272,7 +270,7 @@
             **args.model_conf,
         )
-    elif args.model_name == "eend_ola":
+    elif args.model == "eend_ola":
         # encoder-decoder attractor
         encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
         encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
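Note: since `input_size` is no longer injected by build_diar_model, the EEND-OLA encoder must carry its input dimension in `encoder_conf` itself. A sketch of what that config might hold (key names and sizes are assumptions, not the committed defaults):

    # Hypothetical encoder_conf once input_size is no longer threaded through.
    encoder_conf = dict(
        idim=345,        # e.g. 23 log-mel bins x 15 spliced frames (WavFrontendMel23-style)
        n_layers=4,
        n_units=256,
        e_units=2048,
        dropout_rate=0.1,
    )
    encoder = encoder_class(**encoder_conf)  # as in the line added above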
funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -57,7 +57,7 @@
             data_path_and_name_and_type,
             preprocess=preprocess_fn,
             dest_sample_rate=dest_sample_rate,
-            speed_perturb=args.speed_perturb if mode=="train" else None,
+            speed_perturb=args.speed_perturb if mode == "train" else None,
         )

         # sampler
@@ -84,7 +84,7 @@
             args.max_update = len(bs_list) * args.max_epoch
             logging.info("Max update: {}".format(args.max_update))
-        if args.distributed and mode=="train":
+        if args.distributed and mode == "train":
             world_size = torch.distributed.get_world_size()
             rank = torch.distributed.get_rank()
             for batch in batches:
funasr/models/e2e_diar_eend_ola.py
@@ -1,21 +1,21 @@
 # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
 #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 from contextlib import contextmanager
 from distutils.version import LooseVersion
-from typing import Dict
-from typing import Tuple
+from typing import Dict, List, Tuple, Optional

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from typeguard import check_argument_types

-from funasr.models.base_model import FunASRModel
 from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.modules.eend_ola.utils.losses import fast_batch_pit_n_speaker_loss, standard_loss, cal_power_loss
 from funasr.modules.eend_ola.utils.power import create_powerlabel
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
 from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel

 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     pass
@@ -33,12 +33,35 @@
     return att


+def pad_labels(ts, out_size):
+    for i, t in enumerate(ts):
+        if t.shape[1] < out_size:
+            ts[i] = F.pad(
+                t,
+                (0, out_size - t.shape[1], 0, 0),
+                mode='constant',
+                value=0.
+            )
+    return ts
+
+
+def pad_results(ys, out_size):
+    ys_padded = []
+    for i, y in enumerate(ys):
+        if y.shape[1] < out_size:
+            ys_padded.append(
+                torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
+        else:
+            ys_padded.append(y)
+    return ys_padded
+
+
 class DiarEENDOLAModel(FunASRModel):
     """EEND-OLA diarization model"""

     def __init__(
             self,
-            frontend: WavFrontendMel23,
+            frontend: Optional[WavFrontendMel23],
             encoder: EENDOLATransformerEncoder,
             encoder_decoder_attractor: EncoderDecoderAttractor,
             n_units: int = 256,
@@ -47,11 +70,12 @@
             mapping_dict=None,
             **kwargs,
     ):
         assert check_argument_types()

         super().__init__()
         self.frontend = frontend
         self.enc = encoder
-        self.eda = encoder_decoder_attractor
+        self.encoder_decoder_attractor = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
         if mapping_dict is None:
@@ -74,7 +98,8 @@
     def forward_post_net(self, logits, ilens):
         maxlen = torch.max(ilens).to(torch.int).item()
         logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
-        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
+                                                   enforce_sorted=False)
         outputs, (_, _) = self.postnet(logits)
         outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
         outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
@@ -83,95 +108,51 @@
     def forward(
             self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+            speech: List[torch.Tensor],
+            speech_lengths: torch.Tensor,  # num_frames of each sample
+            speaker_labels: List[torch.Tensor],
+            speaker_labels_lengths: torch.Tensor,  # num_speakers of each sample
+            orders: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss

         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
-            text: (Batch, Length)
-            text_lengths: (Batch,)
         """
-        assert text_lengths.dim() == 1, text_lengths.shape
-        # Check that batch_size is unified
         assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-        batch_size = speech.shape[0]
+                len(speech)
+                == len(speech_lengths)
+                == len(speaker_labels)
+                == len(speaker_labels_lengths)
+        ), (len(speech), len(speech_lengths), len(speaker_labels), len(speaker_labels_lengths))
+        batch_size = len(speech)

-        # for data-parallel
-        text = text[:, : text_lengths.max()]
+        # Encoder
+        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
+        encoder_out = self.forward_encoder(speech, speech_lengths)

-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
+        # Encoder-decoder attractor
+        attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
+                                                                    speaker_labels_lengths)
+        speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]

-        loss_att, acc_att, cer_att, wer_att = None, None, None, None
-        loss_ctc, cer_ctc = None, None
+        # pit loss
+        pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
+        pit_loss = standard_loss(speaker_logits, pit_speaker_labels)

+        # pse loss
+        with torch.no_grad():
+            power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
+                            to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
+        pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
+        pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
+        pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
+        pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
+
+        loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss

         stats = dict()
-        # 1. CTC branch
-        if self.ctc_weight != 0.0:
-            loss_ctc, cer_ctc = self._calc_ctc_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
-            # Collect CTC branch stats
-            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-            stats["cer_ctc"] = cer_ctc

-        # Intermediate CTC (optional)
-        loss_interctc = 0.0
-        if self.interctc_weight != 0.0 and intermediate_outs is not None:
-            for layer_idx, intermediate_out in intermediate_outs:
-                # we assume intermediate_out has the same length & padding
-                # as those of encoder_out
-                loss_ic, cer_ic = self._calc_ctc_loss(
-                    intermediate_out, encoder_out_lens, text, text_lengths
-                )
-                loss_interctc = loss_interctc + loss_ic
-                # Collect Intermedaite CTC stats
-                stats["loss_interctc_layer{}".format(layer_idx)] = (
-                    loss_ic.detach() if loss_ic is not None else None
-                )
-                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-            loss_interctc = loss_interctc / len(intermediate_outs)

-            # calculate whole encoder loss
-            loss_ctc = (
-                               1 - self.interctc_weight
-                       ) * loss_ctc + self.interctc_weight * loss_interctc

-        # 2b. Attention decoder branch
-        if self.ctc_weight != 1.0:
-            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )

-        # 3. CTC-Att loss definition
-        if self.ctc_weight == 0.0:
-            loss = loss_att
-        elif self.ctc_weight == 1.0:
-            loss = loss_ctc
-        else:
-            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

-        # Collect Attn branch stats
-        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-        stats["acc"] = acc_att
-        stats["cer"] = cer_att
-        stats["wer"] = wer_att
+        stats["pse_loss"] = pse_loss.detach()
+        stats["pit_loss"] = pit_loss.detach()
+        stats["attractor_loss"] = attractor_loss.detach()
+        stats["batch_size"] = batch_size

         # Collect total loss stats
         stats["loss"] = torch.clone(loss.detach())
@@ -193,10 +174,10 @@
             orders = [np.arange(e.shape[0]) for e in emb]
             for order in orders:
                 np.random.shuffle(order)
-            attractors, probs = self.eda.estimate(
+            attractors, probs = self.encoder_decoder_attractor.estimate(
                 [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
         else:
-            attractors, probs = self.eda.estimate(emb)
+            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
         attractors_active = []
         for p, att, e in zip(probs, attractors, emb):
             if n_speakers and n_speakers >= 0:
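Note on the shuffling above: LSTM-based attractor estimation is sensitive to the order in which frame embeddings are consumed, so frames are visited in random order, as in EEND-EDA. An equivalent one-liner for the shuffle (sizes assumed):

    # Illustrative: same effect as np.arange + np.random.shuffle above.
    import numpy as np
    import torch

    emb_i = torch.randn(500, 256)                     # (T, D) encoder output for one sample
    order = np.random.permutation(emb_i.shape[0])
    shuffled = emb_i[torch.from_numpy(order).long()]  # frames in random order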
funasr/modules/eend_ola/eend_ola_dataloader.py
New file
@@ -0,0 +1,57 @@
+import logging
+
+import kaldiio
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+
+
+def custom_collate(batch):
+    keys, speech, speaker_labels, orders = zip(*batch)
+    speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
+    speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
+    orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
+    batch = dict(speech=speech,
+                 speaker_labels=speaker_labels,
+                 orders=orders)
+    return keys, batch
+
+
+class EENDOLADataset(Dataset):
+    def __init__(
+            self,
+            data_file,
+    ):
+        self.data_file = data_file
+        with open(data_file) as f:
+            lines = f.readlines()
+        self.samples = [line.strip().split() for line in lines]
+        logging.info("total samples: {}".format(len(self.samples)))
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        key, speech_path, speaker_label_path = self.samples[idx]
+        speech = kaldiio.load_mat(speech_path)
+        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)
+        order = np.arange(speech.shape[0])
+        np.random.shuffle(order)
+        return key, speech, speaker_label, order
+
+
+class EENDOLADataLoader():
+    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
+        dataset = EENDOLADataset(data_file)
+        self.data_loader = DataLoader(dataset,
+                                      batch_size=batch_size,
+                                      collate_fn=custom_collate,
+                                      shuffle=shuffle,
+                                      num_workers=num_workers)
+
+    def build_iter(self, epoch):
+        return self.data_loader
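Usage sketch for the new loader; the data file layout (one "key feats_specifier labels_specifier" line per recording, both readable by kaldiio.load_mat) is inferred from EENDOLADataset.__getitem__, and the path below is hypothetical:

    # Hedged example: iterate one batch from the new loader.
    loader = EENDOLADataLoader(
        data_file="dump/train/eend_ola.scp",
        batch_size=16,
        shuffle=True,
        num_workers=4,
    )
    for keys, batch in loader.build_iter(epoch=0):
        speech = batch["speech"]            # list of (T_i, F) float32 tensors
        labels = batch["speaker_labels"]    # list of (T_i, N) frame-level activity matrices
        orders = batch["orders"]            # per-sample shuffled frame indices, length T_i
        break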
funasr/modules/eend_ola/encoder.py
@@ -91,6 +91,7 @@
                  dropout_rate: float = 0.1,
                  use_pos_emb: bool = False):
         super(EENDOLATransformerEncoder, self).__init__()
+        self.linear_in = nn.Linear(idim, n_units)
         self.lnorm_in = nn.LayerNorm(n_units)
         self.n_layers = n_layers
         self.dropout = nn.Dropout(dropout_rate)
@@ -104,25 +105,10 @@
             setattr(self, '{}{:d}'.format("ff_", i),
                     PositionwiseFeedForward(n_units, e_units, dropout_rate))
         self.lnorm_out = nn.LayerNorm(n_units)
-        if use_pos_emb:
-            self.pos_enc = torch.nn.Sequential(
-                torch.nn.Linear(idim, n_units),
-                torch.nn.LayerNorm(n_units),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                PositionalEncoding(n_units, dropout_rate),
-            )
-        else:
-            self.linear_in = nn.Linear(idim, n_units)
-            self.pos_enc = None

     def __call__(self, x, x_mask=None):
         BT_size = x.shape[0] * x.shape[1]
-        if self.pos_enc is not None:
-            e = self.pos_enc(x)
-            e = e.view(BT_size, -1)
-        else:
-            e = self.linear_in(x.reshape(BT_size, -1))
+        e = self.linear_in(x.reshape(BT_size, -1))
         for i in range(self.n_layers):
             e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
             s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
@@ -130,4 +116,4 @@
             e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
             s = getattr(self, '{}{:d}'.format("ff_", i))(e)
             e = e + self.dropout(s)
-        return self.lnorm_out(e)
\ No newline at end of file
+        return self.lnorm_out(e)
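Note: the positional-embedding branch is dropped, so linear_in is now created and applied unconditionally and use_pos_emb becomes a dead parameter. A shape sketch under assumed constructor arguments (the real signature may take more):

    # Illustrative shapes only; constructor arguments here are assumptions.
    import torch

    enc = EENDOLATransformerEncoder(idim=345, n_layers=4, n_units=256, e_units=2048)
    x = torch.randn(2, 100, 345)   # (B, T, idim)
    e = enc(x)                     # flattened to (B*T, idim) internally -> (200, 256)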
funasr/modules/eend_ola/utils/losses.py
@@ -1,11 +1,10 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
 from itertools import permutations
-from torch import nn
+from scipy.optimize import linear_sum_assignment


-def standard_loss(ys, ts, label_delay=0):
+def standard_loss(ys, ts):
     losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
     loss = torch.sum(torch.stack(losses))
     n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
@@ -13,55 +12,29 @@
     return loss


-def batch_pit_n_speaker_loss(ys, ts, n_speakers_list):
-    max_n_speakers = ts[0].shape[1]
-    olens = [y.shape[0] for y in ys]
-    ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1)
-    ys_mask = [torch.ones(olen).to(ys.device) for olen in olens]
-    ys_mask = torch.nn.utils.rnn.pad_sequence(ys_mask, batch_first=True, padding_value=0).unsqueeze(-1)
+def fast_batch_pit_n_speaker_loss(ys, ts):
+    with torch.no_grad():
+        bs = len(ys)
+        indices = []
+        for b in range(bs):
+            y = ys[b].transpose(0, 1)
+            t = ts[b].transpose(0, 1)
+            C, _ = t.shape
+            y = y[:, None, :].repeat(1, C, 1)
+            t = t[None, :, :].repeat(C, 1, 1)
+            bce_loss = F.binary_cross_entropy(torch.sigmoid(y), t, reduction="none").mean(-1)
+            C = bce_loss.cpu()
+            indices.append(linear_sum_assignment(C))
+    labels_perm = [t[:, idx[1]] for t, idx in zip(ts, indices)]
-    losses = []
-    for shift in range(max_n_speakers):
-        ts_roll = [torch.roll(t, -shift, dims=1) for t in ts]
-        ts_roll = nn.utils.rnn.pad_sequence(ts_roll, batch_first=True, padding_value=-1)
-        loss = F.binary_cross_entropy(torch.sigmoid(ys), ts_roll, reduction='none')
-        if ys_mask is not None:
-            loss = loss * ys_mask
-        loss = torch.sum(loss, dim=1)
-        losses.append(loss)
-    losses = torch.stack(losses, dim=2)
+    return labels_perm
-    perms = np.array(list(permutations(range(max_n_speakers)))).astype(np.float32)
-    perms = torch.from_numpy(perms).to(losses.device)
-    y_ind = torch.arange(max_n_speakers, dtype=torch.float32, device=losses.device)
-    t_inds = torch.fmod(perms - y_ind, max_n_speakers).to(torch.long)
-    losses_perm = []
-    for t_ind in t_inds:
-        losses_perm.append(
-            torch.mean(losses[:, y_ind.to(torch.long), t_ind], dim=1))
-    losses_perm = torch.stack(losses_perm, dim=1)

-    def select_perm_indices(num, max_num):
-        perms = list(permutations(range(max_num)))
-        sub_perms = list(permutations(range(num)))
-        return [
-            [x[:num] for x in perms].index(perm)
-            for perm in sub_perms]

-    masks = torch.full_like(losses_perm, device=losses.device, fill_value=float('inf'))
-    for i, t in enumerate(ts):
-        n_speakers = n_speakers_list[i]
-        indices = select_perm_indices(n_speakers, max_n_speakers)
-        masks[i, indices] = 0
-    losses_perm += masks
-    min_loss = torch.sum(torch.min(losses_perm, dim=1)[0])
-    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(losses.device)
-    min_loss = min_loss / n_frames
-    min_indices = torch.argmin(losses_perm, dim=1)
-    labels_perm = [t[:, perms[idx].to(torch.long)] for t, idx in zip(ts, min_indices)]
-    labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)]
-    return min_loss, labels_perm


 def cal_power_loss(logits, power_ts):
     losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) for logit, power_t in
               zip(logits, power_ts)]
     loss = torch.sum(torch.stack(losses))
     n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))).to(torch.float32).to(
         power_ts[0].device)
     loss = loss / n_frames
     return loss
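Note: fast_batch_pit_n_speaker_loss replaces the O(N!) permutation search of the old batch_pit_n_speaker_loss with per-sample Hungarian matching on an N x N BCE cost matrix, and it now returns only the permuted labels; the loss itself comes from standard_loss. A standalone illustration of the matching step:

    # The core trick, isolated: linear_sum_assignment picks the label column
    # for each prediction row that minimizes total cost in O(N^3).
    import numpy as np
    from scipy.optimize import linear_sum_assignment

    cost = np.array([[0.9, 0.1],   # cost[i, j]: BCE if prediction i is paired with label j
                     [0.2, 0.8]])
    rows, cols = linear_sum_assignment(cost)
    print(rows, cols)              # [0 1] [1 0]: pred 0 <-> label 1, pred 1 <-> label 0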