From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/models/e2e_diar_eend_ola.py | 510 ++++++++++++++++++--------------------------------------
1 files changed, 168 insertions(+), 342 deletions(-)
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 967c0d4..a0b545a 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -1,38 +1,23 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-import logging
-import torch
from contextlib import contextmanager
from distutils.version import LooseVersion
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
-)
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.modules.e2e_asr_common import ErrorCalculator
-from funasr.modules.eend_ola.encoder import TransformerEncoder
+from typing import Dict, List, Tuple, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from funasr.models.base_model import FunASRModel
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.modules.eend_ola.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
+from funasr.modules.eend_ola.utils.power import create_powerlabel
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
-from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from typeguard import check_argument_types
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
- from torch.cuda.amp import autocast
+ pass
else:
# Nothing to do if torch<1.6.0
@contextmanager
@@ -40,121 +25,125 @@
yield
-class DiarEENDOLAModel(AbsESPnetModel):
- """CTC-attention hybrid Encoder-Decoder model"""
+def pad_attractor(att, max_n_speakers):
+ C, D = att.shape
+ if C < max_n_speakers:
+ att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
+ return att
+
+
+def pad_labels(ts, out_size):
+ for i, t in enumerate(ts):
+ if t.shape[1] < out_size:
+ ts[i] = F.pad(
+ t,
+ (0, out_size - t.shape[1], 0, 0),
+ mode='constant',
+ value=0.
+ )
+ return ts
+
+
+def pad_results(ys, out_size):
+ ys_padded = []
+ for i, y in enumerate(ys):
+ if y.shape[1] < out_size:
+ ys_padded.append(
+ torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
+ else:
+ ys_padded.append(y)
+ return ys_padded
+
+
+class DiarEENDOLAModel(FunASRModel):
+ """EEND-OLA diarization model"""
def __init__(
self,
- encoder: TransformerEncoder,
- eda: EncoderDecoderAttractor,
+ frontend: Optional[WavFrontendMel23],
+ encoder: EENDOLATransformerEncoder,
+ encoder_decoder_attractor: EncoderDecoderAttractor,
+ n_units: int = 256,
max_n_speaker: int = 8,
attractor_loss_weight: float = 1.0,
mapping_dict=None,
**kwargs,
):
- assert check_argument_types()
-
super().__init__()
- self.encoder = encoder
- self.eda = eda
+ self.frontend = frontend
+ self.enc = encoder
+ self.encoder_decoder_attractor = encoder_decoder_attractor
self.attractor_loss_weight = attractor_loss_weight
self.max_n_speaker = max_n_speaker
if mapping_dict is None:
mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
self.mapping_dict = mapping_dict
+ # PostNet
+ self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
+ self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
+
+ def forward_encoder(self, xs, ilens):
+ xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
+ pad_shape = xs.shape
+ xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
+ xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
+ emb = self.enc(xs, xs_mask)
+ emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
+ emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
+ return emb
+
+ def forward_post_net(self, logits, ilens):
+ maxlen = torch.max(ilens).to(torch.int).item()
+ logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
+ logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
+ enforce_sorted=False)
+ outputs, (_, _) = self.postnet(logits)
+ outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
+ outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
+ outputs = [self.output_layer(output) for output in outputs]
+ return outputs
def forward(
self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ speech: List[torch.Tensor],
+ speaker_labels: List[torch.Tensor],
+ orders: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Frontend + Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
- assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
- ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
- batch_size = speech.shape[0]
+ assert (len(speech) == len(speaker_labels)), (len(speech), len(speaker_labels))
+ speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
+ speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64)
+ batch_size = len(speech)
- # for data-parallel
- text = text[:, : text_lengths.max()]
+ # Encoder
+ encoder_out = self.forward_encoder(speech, speech_lengths)
- # 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
+ # Encoder-decoder attractor
+ attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
+ speaker_labels_lengths)
+ speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
- loss_att, acc_att, cer_att, wer_att = None, None, None, None
- loss_ctc, cer_ctc = None, None
+ # pit loss
+ pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
+ pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
+
+ # pse loss
+ with torch.no_grad():
+ power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
+ to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
+ pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
+ pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
+ pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
+ pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
+
+ loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
+
stats = dict()
-
- # 1. CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
- # Intermediate CTC (optional)
- loss_interctc = 0.0
- if self.interctc_weight != 0.0 and intermediate_outs is not None:
- for layer_idx, intermediate_out in intermediate_outs:
- # we assume intermediate_out has the same length & padding
- # as those of encoder_out
- loss_ic, cer_ic = self._calc_ctc_loss(
- intermediate_out, encoder_out_lens, text, text_lengths
- )
- loss_interctc = loss_interctc + loss_ic
-
- # Collect Intermedaite CTC stats
- stats["loss_interctc_layer{}".format(layer_idx)] = (
- loss_ic.detach() if loss_ic is not None else None
- )
- stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
- loss_interctc = loss_interctc / len(intermediate_outs)
-
- # calculate whole encoder loss
- loss_ctc = (
- 1 - self.interctc_weight
- ) * loss_ctc + self.interctc_weight * loss_interctc
-
- # 2b. Attention decoder branch
- if self.ctc_weight != 1.0:
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att
- elif self.ctc_weight == 1.0:
- loss = loss_ctc
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
+ stats["pse_loss"] = pse_loss.detach()
+ stats["pit_loss"] = pit_loss.detach()
+ stats["attractor_loss"] = attractor_loss.detach()
+ stats["batch_size"] = batch_size
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
@@ -163,233 +152,70 @@
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
- def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- ) -> Dict[str, torch.Tensor]:
- if self.extract_feats_in_collect_stats:
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ def estimate_sequential(self,
+ speech: torch.Tensor,
+ n_speakers: int = None,
+ shuffle: bool = True,
+ threshold: float = 0.5,
+ **kwargs):
+ speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
+ emb = self.forward_encoder(speech, speech_lengths)
+ if shuffle:
+ orders = [np.arange(e.shape[0]) for e in emb]
+ for order in orders:
+ np.random.shuffle(order)
+ attractors, probs = self.encoder_decoder_attractor.estimate(
+ [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
else:
- # Generate dummy stats if extract_feats_in_collect_stats is False
- logging.warning(
- "Generating dummy stats for feats and feats_lengths, "
- "because encoder_conf.extract_feats_in_collect_stats is "
- f"{self.extract_feats_in_collect_stats}"
- )
- feats, feats_lengths = speech, speech_lengths
- return {"feats": feats, "feats_lengths": feats_lengths}
+ attractors, probs = self.encoder_decoder_attractor.estimate(emb)
+ attractors_active = []
+ for p, att, e in zip(probs, attractors, emb):
+ if n_speakers and n_speakers >= 0:
+ att = att[:n_speakers, ]
+ attractors_active.append(att)
+ elif threshold is not None:
+ silence = torch.nonzero(p < threshold)[0]
+ n_spk = silence[0] if silence.size else None
+ att = att[:n_spk, ]
+ attractors_active.append(att)
+ else:
+ NotImplementedError('n_speakers or threshold has to be given.')
+ raw_n_speakers = [att.shape[0] for att in attractors_active]
+ attractors = [
+ pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
+ for att in attractors_active]
+ ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
+ logits = self.forward_post_net(ys, speech_lengths)
+ ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
+ zip(logits, raw_n_speakers)]
- def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
+ return ys, emb, attractors, raw_n_speakers
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- """
- with autocast(False):
- # 1. Extract feats
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ def recover_y_from_powerlabel(self, logit, n_speaker):
+ pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
+ oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
+ for i in oov_index:
+ if i > 0:
+ pred[i] = pred[i - 1]
+ else:
+ pred[i] = 0
+ pred = [self.inv_mapping_func(i) for i in pred]
+ decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
+ decisions = torch.from_numpy(
+ np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
+ torch.float32)
+ decisions = decisions[:, :n_speaker]
+ return decisions
- # 2. Data augmentation
- if self.specaug is not None and self.training:
- feats, feats_lengths = self.specaug(feats, feats_lengths)
+ def inv_mapping_func(self, label):
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- feats, feats_lengths = self.normalize(feats, feats_lengths)
-
- # Pre-encoder, e.g. used for raw input data
- if self.preencoder is not None:
- feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
- # 4. Forward encoder
- # feats: (Batch, Length, Dim)
- # -> encoder_out: (Batch, Length2, Dim2)
- if self.encoder.interctc_use_conditioning:
- encoder_out, encoder_out_lens, _ = self.encoder(
- feats, feats_lengths, ctc=self.ctc
- )
+ if not isinstance(label, int):
+ label = int(label)
+ if label in self.mapping_dict['label2dec'].keys():
+ num = self.mapping_dict['label2dec'][label]
else:
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
+ num = -1
+ return num
- # Post-encoder, e.g. NLU
- if self.postencoder is not None:
- encoder_out, encoder_out_lens = self.postencoder(
- encoder_out, encoder_out_lens
- )
-
- assert encoder_out.size(0) == speech.size(0), (
- encoder_out.size(),
- speech.size(0),
- )
- assert encoder_out.size(1) <= encoder_out_lens.max(), (
- encoder_out.size(),
- encoder_out_lens.max(),
- )
-
- if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
-
- return encoder_out, encoder_out_lens
-
- def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- assert speech_lengths.dim() == 1, speech_lengths.shape
-
- # for data-parallel
- speech = speech[:, : speech_lengths.max()]
-
- if self.frontend is not None:
- # Frontend
- # e.g. STFT and Feature extract
- # data_loader may send time-domain signal in this case
- # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
- feats, feats_lengths = self.frontend(speech, speech_lengths)
- else:
- # No frontend and no feature extract
- feats, feats_lengths = speech, speech_lengths
- return feats, feats_lengths
-
- def nll(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ) -> torch.Tensor:
- """Compute negative log likelihood(nll) from transformer-decoder
-
- Normally, this function is called in batchify_nll.
-
- Args:
- encoder_out: (Batch, Length, Dim)
- encoder_out_lens: (Batch,)
- ys_pad: (Batch, Length)
- ys_pad_lens: (Batch,)
- """
- ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_in_lens = ys_pad_lens + 1
-
- # 1. Forward decoder
- decoder_out, _ = self.decoder(
- encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
- ) # [batch, seqlen, dim]
- batch_size = decoder_out.size(0)
- decoder_num_class = decoder_out.size(2)
- # nll: negative log-likelihood
- nll = torch.nn.functional.cross_entropy(
- decoder_out.view(-1, decoder_num_class),
- ys_out_pad.view(-1),
- ignore_index=self.ignore_id,
- reduction="none",
- )
- nll = nll.view(batch_size, -1)
- nll = nll.sum(dim=1)
- assert nll.size(0) == batch_size
- return nll
-
- def batchify_nll(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- batch_size: int = 100,
- ):
- """Compute negative log likelihood(nll) from transformer-decoder
-
- To avoid OOM, this fuction seperate the input into batches.
- Then call nll for each batch and combine and return results.
- Args:
- encoder_out: (Batch, Length, Dim)
- encoder_out_lens: (Batch,)
- ys_pad: (Batch, Length)
- ys_pad_lens: (Batch,)
- batch_size: int, samples each batch contain when computing nll,
- you may change this to avoid OOM or increase
- GPU memory usage
- """
- total_num = encoder_out.size(0)
- if total_num <= batch_size:
- nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
- else:
- nll = []
- start_idx = 0
- while True:
- end_idx = min(start_idx + batch_size, total_num)
- batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
- batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
- batch_ys_pad = ys_pad[start_idx:end_idx, :]
- batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
- batch_nll = self.nll(
- batch_encoder_out,
- batch_encoder_out_lens,
- batch_ys_pad,
- batch_ys_pad_lens,
- )
- nll.append(batch_nll)
- start_idx = end_idx
- if start_idx == total_num:
- break
- nll = torch.cat(nll)
- assert nll.size(0) == total_num
- return nll
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_in_lens = ys_pad_lens + 1
-
- # 1. Forward decoder
- decoder_out, _ = self.decoder(
- encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
- )
-
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_out_pad)
- acc_att = th_accuracy(
- decoder_out.view(-1, self.vocab_size),
- ys_out_pad,
- ignore_label=self.ignore_id,
- )
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att
-
- def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- # Calc CTC loss
- loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
- # Calc CER using CTC
- cer_ctc = None
- if not self.training and self.error_calculator is not None:
- ys_hat = self.ctc.argmax(encoder_out).data
- cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
- return loss_ctc, cer_ctc
+ def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
+ pass
\ No newline at end of file
--
Gitblit v1.9.1