From ff0310bfb4ed69f00cbeab89a58f958ae5091d70 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 06 Jul 2023 16:24:35 +0800
Subject: [PATCH] update eend_ola: dedicated dataloader, simplified model, faster PIT loss
---
funasr/datasets/small_datasets/sequence_iter_factory.py | 4
funasr/build_utils/build_args.py | 6 +
funasr/modules/eend_ola/encoder.py | 20 --
funasr/build_utils/build_dataloader.py | 17 ++
funasr/modules/eend_ola/utils/losses.py | 77 ++++--------
funasr/build_utils/build_diar_model.py | 6
funasr/modules/eend_ola/eend_ola_dataloader.py | 62 +++++++++
funasr/models/e2e_diar_eend_ola.py | 167 ++++++++++++---------------
8 files changed, 189 insertions(+), 170 deletions(-)
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
index 632c134..31f210e 100644
--- a/funasr/build_utils/build_args.py
+++ b/funasr/build_utils/build_args.py
@@ -86,6 +86,12 @@
from funasr.build_utils.build_diar_model import class_choices_list
for class_choices in class_choices_list:
class_choices.add_arguments(task_parser)
+        task_parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The input feature dimension",
+        )
elif args.task_name == "sv":
from funasr.build_utils.build_sv_model import class_choices_list
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
index c95c40d..473097e 100644
--- a/funasr/build_utils/build_dataloader.py
+++ b/funasr/build_utils/build_dataloader.py
@@ -4,8 +4,21 @@
def build_dataloader(args):
if args.dataset_type == "small":
- train_iter_factory = SequenceIterFactory(args, mode="train")
- valid_iter_factory = SequenceIterFactory(args, mode="valid")
+ if args.task_name == "diar" and args.model == "eend_ola":
+ from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
+ train_iter_factory = EENDOLADataLoader(
+ data_file=args.train_data_path_and_name_and_type[0][0],
+ batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+ num_workers=args.dataset_conf["num_workers"],
+ shuffle=True)
+ valid_iter_factory = EENDOLADataLoader(
+ data_file=args.valid_data_path_and_name_and_type[0][0],
+ batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+ num_workers=0,
+ shuffle=False)
+ else:
+ train_iter_factory = SequenceIterFactory(args, mode="train")
+ valid_iter_factory = SequenceIterFactory(args, mode="valid")
elif args.dataset_type == "large":
train_iter_factory = LargeDataLoader(args, mode="train")
valid_iter_factory = LargeDataLoader(args, mode="valid")
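
For orientation, the new EEND-OLA branch above reads only a handful of fields from args. A minimal sketch of a namespace that would reach it; the attribute names come from the diff, while the file paths and values are placeholders:

    from types import SimpleNamespace

    args = SimpleNamespace(
        dataset_type="small",
        task_name="diar",
        model="eend_ola",
        # (path, name, type) triples; only the first path is read by this branch
        train_data_path_and_name_and_type=[("data/train.list", "speech", "kaldi_ark")],
        valid_data_path_and_name_and_type=[("data/valid.list", "speech", "kaldi_ark")],
        dataset_conf={"batch_conf": {"batch_size": 64}, "num_workers": 8},
    )
    # build_dataloader(args) would then construct the two EENDOLADataLoader instances
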
diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py
index 0ea3127..444636a 100644
--- a/funasr/build_utils/build_diar_model.py
+++ b/funasr/build_utils/build_diar_model.py
@@ -198,16 +198,14 @@
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
- input_size = frontend.output_size()
else:
args.frontend = None
args.frontend_conf = {}
frontend = None
- input_size = args.input_size
# encoder
encoder_class = encoder_choices.get_class(args.encoder)
- encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+ encoder = encoder_class(**args.encoder_conf)
if args.model == "sond":
# data augmentation for spectrogram
@@ -272,7 +270,7 @@
**args.model_conf,
)
- elif args.model_name == "eend_ola":
+ elif args.model == "eend_ola":
# encoder-decoder attractor
encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
diff --git a/funasr/datasets/small_datasets/sequence_iter_factory.py b/funasr/datasets/small_datasets/sequence_iter_factory.py
index 3ebcc5a..e748c3d 100644
--- a/funasr/datasets/small_datasets/sequence_iter_factory.py
+++ b/funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -57,7 +57,7 @@
data_path_and_name_and_type,
preprocess=preprocess_fn,
dest_sample_rate=dest_sample_rate,
- speed_perturb=args.speed_perturb if mode=="train" else None,
+ speed_perturb=args.speed_perturb if mode == "train" else None,
)
# sampler
@@ -84,7 +84,7 @@
args.max_update = len(bs_list) * args.max_epoch
logging.info("Max update: {}".format(args.max_update))
- if args.distributed and mode=="train":
+ if args.distributed and mode == "train":
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
for batch in batches:
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index ae3a436..af0fd62 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -1,21 +1,21 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
from contextlib import contextmanager
from distutils.version import LooseVersion
-from typing import Dict
-from typing import Tuple
+from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
+from typeguard import check_argument_types
+from funasr.models.base_model import FunASRModel
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.modules.eend_ola.utils.losses import fast_batch_pit_n_speaker_loss, standard_loss, cal_power_loss
+from funasr.modules.eend_ola.utils.power import create_powerlabel
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
pass
@@ -33,12 +33,35 @@
return att
+def pad_labels(ts, out_size):
+ for i, t in enumerate(ts):
+ if t.shape[1] < out_size:
+ ts[i] = F.pad(
+ t,
+ (0, out_size - t.shape[1], 0, 0),
+ mode='constant',
+ value=0.
+ )
+ return ts
+
+
+def pad_results(ys, out_size):
+ ys_padded = []
+ for i, y in enumerate(ys):
+ if y.shape[1] < out_size:
+ ys_padded.append(
+ torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
+ else:
+ ys_padded.append(y)
+ return ys_padded
+
+
class DiarEENDOLAModel(FunASRModel):
"""EEND-OLA diarization model"""
def __init__(
self,
- frontend: WavFrontendMel23,
+ frontend: Optional[WavFrontendMel23],
encoder: EENDOLATransformerEncoder,
encoder_decoder_attractor: EncoderDecoderAttractor,
n_units: int = 256,
@@ -47,11 +70,12 @@
mapping_dict=None,
**kwargs,
):
+ assert check_argument_types()
super().__init__()
self.frontend = frontend
self.enc = encoder
- self.eda = encoder_decoder_attractor
+ self.encoder_decoder_attractor = encoder_decoder_attractor
self.attractor_loss_weight = attractor_loss_weight
self.max_n_speaker = max_n_speaker
if mapping_dict is None:
@@ -74,7 +98,8 @@
def forward_post_net(self, logits, ilens):
maxlen = torch.max(ilens).to(torch.int).item()
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
- logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+ logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
+ enforce_sorted=False)
outputs, (_, _) = self.postnet(logits)
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
@@ -83,95 +108,51 @@
def forward(
self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ speech: List[torch.Tensor],
+ speech_lengths: torch.Tensor, # num_frames of each sample
+ speaker_labels: List[torch.Tensor],
+ speaker_labels_lengths: torch.Tensor, # num_speakers of each sample
+ orders: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Frontend + Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- assert text_lengths.dim() == 1, text_lengths.shape
+
# Check that batch_size is unified
assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
- ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
- batch_size = speech.shape[0]
+ len(speech)
+ == len(speech_lengths)
+ == len(speaker_labels)
+ == len(speaker_labels_lengths)
+ ), (len(speech), len(speech_lengths), len(speaker_labels), len(speaker_labels_lengths))
+ batch_size = len(speech)
- # for data-parallel
- text = text[:, : text_lengths.max()]
+ # Encoder
+ speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
+ encoder_out = self.forward_encoder(speech, speech_lengths)
- # 1. Encoder
- encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
+ # Encoder-decoder attractor
+ attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
+ speaker_labels_lengths)
+ speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
- loss_att, acc_att, cer_att, wer_att = None, None, None, None
- loss_ctc, cer_ctc = None, None
+ # pit loss
+ pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
+ pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
+
+ # pse loss
+ with torch.no_grad():
+            power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker)
+                        .to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
+ pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
+ pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
+ pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
+ pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
+
+ loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
+
stats = dict()
-
- # 1. CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
- # Intermediate CTC (optional)
- loss_interctc = 0.0
- if self.interctc_weight != 0.0 and intermediate_outs is not None:
- for layer_idx, intermediate_out in intermediate_outs:
- # we assume intermediate_out has the same length & padding
- # as those of encoder_out
- loss_ic, cer_ic = self._calc_ctc_loss(
- intermediate_out, encoder_out_lens, text, text_lengths
- )
- loss_interctc = loss_interctc + loss_ic
-
- # Collect Intermedaite CTC stats
- stats["loss_interctc_layer{}".format(layer_idx)] = (
- loss_ic.detach() if loss_ic is not None else None
- )
- stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
- loss_interctc = loss_interctc / len(intermediate_outs)
-
- # calculate whole encoder loss
- loss_ctc = (
- 1 - self.interctc_weight
- ) * loss_ctc + self.interctc_weight * loss_interctc
-
- # 2b. Attention decoder branch
- if self.ctc_weight != 1.0:
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att
- elif self.ctc_weight == 1.0:
- loss = loss_ctc
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
+ stats["pse_loss"] = pse_loss.detach()
+ stats["pit_loss"] = pit_loss.detach()
+ stats["attractor_loss"] = attractor_loss.detach()
+ stats["batch_size"] = batch_size
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
@@ -193,10 +174,10 @@
orders = [np.arange(e.shape[0]) for e in emb]
for order in orders:
np.random.shuffle(order)
- attractors, probs = self.eda.estimate(
+ attractors, probs = self.encoder_decoder_attractor.estimate(
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
else:
- attractors, probs = self.eda.estimate(emb)
+ attractors, probs = self.encoder_decoder_attractor.estimate(emb)
attractors_active = []
for p, att, e in zip(probs, attractors, emb):
if n_speakers and n_speakers >= 0:
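
The new "pse" branch above trains the post-net with a single cross-entropy over power-set classes rather than per-speaker binary targets: create_powerlabel collapses each frame's multi-hot speaker vector into one class index. A toy illustration of the idea, using one possible bijection; this is not the actual generate_mapping_dict/create_powerlabel implementation:

    import numpy as np

    S = 3  # assumed maximum number of speakers
    frame_labels = np.array([[0, 0, 0],   # silence
                             [1, 0, 0],   # speaker 0 only
                             [1, 1, 0]])  # speakers 0 and 1 overlap
    # read each multi-hot row as a binary number -> one class per speaker subset
    power_labels = frame_labels.dot(1 << np.arange(S))
    print(power_labels)  # [0 1 3]; 2**S classes in total
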
diff --git a/funasr/modules/eend_ola/eend_ola_dataloader.py b/funasr/modules/eend_ola/eend_ola_dataloader.py
new file mode 100644
index 0000000..2ee9272
--- /dev/null
+++ b/funasr/modules/eend_ola/eend_ola_dataloader.py
@@ -0,0 +1,62 @@
+import logging
+
+import kaldiio
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+
+
+def custom_collate(batch):
+    # Keep variable-length samples as lists of tensors; padding is deferred
+    # to the model, which consumes per-sample tensors directly.
+    keys, speech, speaker_labels, orders = zip(*batch)
+    speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
+    speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
+    orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
+    batch = dict(speech=speech,
+                 speaker_labels=speaker_labels,
+                 orders=orders)
+
+    return keys, batch
+
+
+class EENDOLADataset(Dataset):
+ def __init__(
+ self,
+ data_file,
+ ):
+ self.data_file = data_file
+ with open(data_file) as f:
+ lines = f.readlines()
+ self.samples = [line.strip().split() for line in lines]
+ logging.info("total samples: {}".format(len(self.samples)))
+
+ def __len__(self):
+ return len(self.samples)
+
+    def __getitem__(self, idx):
+        # Each line of data_file: "<key> <speech_ark_path> <speaker_label_ark_path>"
+        key, speech_path, speaker_label_path = self.samples[idx]
+        speech = kaldiio.load_mat(speech_path)
+        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(speech.shape[0], -1)
+
+        # Shuffled frame order, later consumed by the encoder-decoder attractor
+        order = np.arange(speech.shape[0])
+        np.random.shuffle(order)
+
+        return key, speech, speaker_label, order
+
+
+class EENDOLADataLoader:
+ def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
+ dataset = EENDOLADataset(data_file)
+ self.data_loader = DataLoader(dataset,
+ batch_size=batch_size,
+ collate_fn=custom_collate,
+ shuffle=shuffle,
+ num_workers=num_workers)
+
+    def build_iter(self, epoch):
+        # epoch is unused; shuffling is delegated to the wrapped DataLoader
+        return self.data_loader
\ No newline at end of file
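
A hypothetical smoke test for the new loader, assuming a list file in which each line holds "<key> <speech_ark_path> <speaker_label_ark_path>" (the path below is a placeholder):

    from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader

    loader = EENDOLADataLoader("data/train.list", batch_size=4, num_workers=0)
    for keys, batch in loader.build_iter(epoch=0):
        # speech, speaker_labels and orders remain lists of variable-length
        # tensors; custom_collate performs no padding
        print(keys[0], batch["speech"][0].shape, batch["speaker_labels"][0].shape)
        break
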
diff --git a/funasr/modules/eend_ola/encoder.py b/funasr/modules/eend_ola/encoder.py
index 90a63f3..3065884 100644
--- a/funasr/modules/eend_ola/encoder.py
+++ b/funasr/modules/eend_ola/encoder.py
@@ -91,6 +91,7 @@
dropout_rate: float = 0.1,
use_pos_emb: bool = False):
super(EENDOLATransformerEncoder, self).__init__()
+ self.linear_in = nn.Linear(idim, n_units)
self.lnorm_in = nn.LayerNorm(n_units)
self.n_layers = n_layers
self.dropout = nn.Dropout(dropout_rate)
@@ -104,25 +105,10 @@
setattr(self, '{}{:d}'.format("ff_", i),
PositionwiseFeedForward(n_units, e_units, dropout_rate))
self.lnorm_out = nn.LayerNorm(n_units)
- if use_pos_emb:
- self.pos_enc = torch.nn.Sequential(
- torch.nn.Linear(idim, n_units),
- torch.nn.LayerNorm(n_units),
- torch.nn.Dropout(dropout_rate),
- torch.nn.ReLU(),
- PositionalEncoding(n_units, dropout_rate),
- )
- else:
- self.linear_in = nn.Linear(idim, n_units)
- self.pos_enc = None
def __call__(self, x, x_mask=None):
BT_size = x.shape[0] * x.shape[1]
- if self.pos_enc is not None:
- e = self.pos_enc(x)
- e = e.view(BT_size, -1)
- else:
- e = self.linear_in(x.reshape(BT_size, -1))
+ e = self.linear_in(x.reshape(BT_size, -1))
for i in range(self.n_layers):
e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
@@ -130,4 +116,4 @@
e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
s = getattr(self, '{}{:d}'.format("ff_", i))(e)
e = e + self.dropout(s)
- return self.lnorm_out(e)
+ return self.lnorm_out(e)
\ No newline at end of file
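
With the positional-encoding branch removed, the encoder input path is a plain per-frame projection applied to frames flattened across batch and time. A shape-level sketch with toy sizes; 23 is assumed to match the WavFrontendMel23 feature dimension:

    import torch
    import torch.nn as nn

    B, T, idim, n_units = 2, 5, 23, 256
    linear_in = nn.Linear(idim, n_units)
    x = torch.randn(B, T, idim)
    e = linear_in(x.reshape(B * T, -1))  # no positional encoding is applied
    print(e.shape)  # torch.Size([10, 256])
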
diff --git a/funasr/modules/eend_ola/utils/losses.py b/funasr/modules/eend_ola/utils/losses.py
index af0181d..756952d 100644
--- a/funasr/modules/eend_ola/utils/losses.py
+++ b/funasr/modules/eend_ola/utils/losses.py
@@ -1,11 +1,10 @@
import numpy as np
import torch
import torch.nn.functional as F
-from itertools import permutations
-from torch import nn
+from scipy.optimize import linear_sum_assignment
-def standard_loss(ys, ts, label_delay=0):
+def standard_loss(ys, ts):
losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
loss = torch.sum(torch.stack(losses))
n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
@@ -13,55 +12,29 @@
return loss
-def batch_pit_n_speaker_loss(ys, ts, n_speakers_list):
- max_n_speakers = ts[0].shape[1]
- olens = [y.shape[0] for y in ys]
- ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-1)
- ys_mask = [torch.ones(olen).to(ys.device) for olen in olens]
- ys_mask = torch.nn.utils.rnn.pad_sequence(ys_mask, batch_first=True, padding_value=0).unsqueeze(-1)
+def fast_batch_pit_n_speaker_loss(ys, ts):
+ with torch.no_grad():
+ bs = len(ys)
+ indices = []
+ for b in range(bs):
+ y = ys[b].transpose(0, 1)
+ t = ts[b].transpose(0, 1)
+ C, _ = t.shape
+ y = y[:, None, :].repeat(1, C, 1)
+ t = t[None, :, :].repeat(C, 1, 1)
+ bce_loss = F.binary_cross_entropy(torch.sigmoid(y), t, reduction="none").mean(-1)
+            cost = bce_loss.cpu()
+            indices.append(linear_sum_assignment(cost))
+ labels_perm = [t[:, idx[1]] for t, idx in zip(ts, indices)]
- losses = []
- for shift in range(max_n_speakers):
- ts_roll = [torch.roll(t, -shift, dims=1) for t in ts]
- ts_roll = nn.utils.rnn.pad_sequence(ts_roll, batch_first=True, padding_value=-1)
- loss = F.binary_cross_entropy(torch.sigmoid(ys), ts_roll, reduction='none')
- if ys_mask is not None:
- loss = loss * ys_mask
- loss = torch.sum(loss, dim=1)
- losses.append(loss)
- losses = torch.stack(losses, dim=2)
+ return labels_perm
- perms = np.array(list(permutations(range(max_n_speakers)))).astype(np.float32)
- perms = torch.from_numpy(perms).to(losses.device)
- y_ind = torch.arange(max_n_speakers, dtype=torch.float32, device=losses.device)
- t_inds = torch.fmod(perms - y_ind, max_n_speakers).to(torch.long)
- losses_perm = []
- for t_ind in t_inds:
- losses_perm.append(
- torch.mean(losses[:, y_ind.to(torch.long), t_ind], dim=1))
- losses_perm = torch.stack(losses_perm, dim=1)
-
- def select_perm_indices(num, max_num):
- perms = list(permutations(range(max_num)))
- sub_perms = list(permutations(range(num)))
- return [
- [x[:num] for x in perms].index(perm)
- for perm in sub_perms]
-
- masks = torch.full_like(losses_perm, device=losses.device, fill_value=float('inf'))
- for i, t in enumerate(ts):
- n_speakers = n_speakers_list[i]
- indices = select_perm_indices(n_speakers, max_n_speakers)
- masks[i, indices] = 0
- losses_perm += masks
-
- min_loss = torch.sum(torch.min(losses_perm, dim=1)[0])
- n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(losses.device)
- min_loss = min_loss / n_frames
-
- min_indices = torch.argmin(losses_perm, dim=1)
- labels_perm = [t[:, perms[idx].to(torch.long)] for t, idx in zip(ts, min_indices)]
- labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)]
-
- return min_loss, labels_perm
+def cal_power_loss(logits, power_ts):
+    losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit)
+              for logit, power_t in zip(logits, power_ts)]
+    loss = torch.sum(torch.stack(losses))
+    n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts])))
+    n_frames = n_frames.to(torch.float32).to(power_ts[0].device)
+    loss = loss / n_frames
+    return loss
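
The rewrite replaces the factorial permutation search of batch_pit_n_speaker_loss with one Hungarian assignment per sample over a speakers-by-speakers BCE cost matrix. A self-contained toy run of the same matching step, using made-up tensors rather than the FunASR API:

    import torch
    import torch.nn.functional as F
    from scipy.optimize import linear_sum_assignment

    T, S = 4, 2                      # frames, speakers
    y = torch.randn(T, S)            # logits for S predicted speakers
    t = torch.tensor([[0., 1.], [0., 1.], [1., 1.], [1., 0.]])  # references

    # cost[i, j] = mean BCE of predicted speaker i against reference speaker j
    yp = torch.sigmoid(y).transpose(0, 1)[:, None, :].expand(S, S, T)
    tp = t.transpose(0, 1)[None, :, :].expand(S, S, T)
    cost = F.binary_cross_entropy(yp, tp, reduction="none").mean(-1)
    row, col = linear_sum_assignment(cost.numpy())
    t_perm = t[:, col]  # references reordered to best match the predictions
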
--
Gitblit v1.9.1