From f8d1c79fe355efb18ae49e4363307dfec3ab89ce Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期一, 07 八月 2023 16:14:11 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main
---
funasr/models/e2e_diar_eend_ola.py | 165 ++++++++++++++++++++++--------------------------------
1 files changed, 68 insertions(+), 97 deletions(-)
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index ae3a436..a0b545a 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -1,21 +1,20 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
from contextlib import contextmanager
from distutils.version import LooseVersion
-from typing import Dict
-from typing import Tuple
+from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.base_model import FunASRModel
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.modules.eend_ola.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
+from funasr.modules.eend_ola.utils.power import create_powerlabel
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
pass
@@ -33,12 +32,35 @@
return att
+def pad_labels(ts, out_size):
+ for i, t in enumerate(ts):
+ if t.shape[1] < out_size:
+ ts[i] = F.pad(
+ t,
+ (0, out_size - t.shape[1], 0, 0),
+ mode='constant',
+ value=0.
+ )
+ return ts
+
+
+def pad_results(ys, out_size):
+ ys_padded = []
+ for i, y in enumerate(ys):
+ if y.shape[1] < out_size:
+ ys_padded.append(
+ torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
+ else:
+ ys_padded.append(y)
+ return ys_padded
+
+
class DiarEENDOLAModel(FunASRModel):
"""EEND-OLA diarization model"""
def __init__(
self,
- frontend: WavFrontendMel23,
+ frontend: Optional[WavFrontendMel23],
encoder: EENDOLATransformerEncoder,
encoder_decoder_attractor: EncoderDecoderAttractor,
n_units: int = 256,
@@ -47,11 +69,10 @@
mapping_dict=None,
**kwargs,
):
-
super().__init__()
self.frontend = frontend
self.enc = encoder
- self.eda = encoder_decoder_attractor
+ self.encoder_decoder_attractor = encoder_decoder_attractor
self.attractor_loss_weight = attractor_loss_weight
self.max_n_speaker = max_n_speaker
if mapping_dict is None:
@@ -74,7 +95,8 @@
def forward_post_net(self, logits, ilens):
maxlen = torch.max(ilens).to(torch.int).item()
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
- logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+ logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
+ enforce_sorted=False)
outputs, (_, _) = self.postnet(logits)
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
@@ -83,95 +105,45 @@
def forward(
self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ speech: List[torch.Tensor],
+ speaker_labels: List[torch.Tensor],
+ orders: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Frontend + Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
- assert text_lengths.dim() == 1, text_lengths.shape
+
# Check that batch_size is unified
- assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
- ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
- batch_size = speech.shape[0]
+ assert (len(speech) == len(speaker_labels)), (len(speech), len(speaker_labels))
+ speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
+ speaker_labels_lengths = torch.tensor([spk.shape[-1] for spk in speaker_labels]).to(torch.int64)
+ batch_size = len(speech)
- # for data-parallel
- text = text[:, : text_lengths.max()]
+ # Encoder
+ encoder_out = self.forward_encoder(speech, speech_lengths)
- # 1. Encoder
- encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
+ # Encoder-decoder attractor
+ attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
+ speaker_labels_lengths)
+ speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
- loss_att, acc_att, cer_att, wer_att = None, None, None, None
- loss_ctc, cer_ctc = None, None
+ # pit loss
+ pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
+ pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
+
+ # pse loss
+ with torch.no_grad():
+ power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
+ to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
+ pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
+ pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
+ pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
+ pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
+
+ loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
+
stats = dict()
-
- # 1. CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
- # Intermediate CTC (optional)
- loss_interctc = 0.0
- if self.interctc_weight != 0.0 and intermediate_outs is not None:
- for layer_idx, intermediate_out in intermediate_outs:
- # we assume intermediate_out has the same length & padding
- # as those of encoder_out
- loss_ic, cer_ic = self._calc_ctc_loss(
- intermediate_out, encoder_out_lens, text, text_lengths
- )
- loss_interctc = loss_interctc + loss_ic
-
- # Collect Intermedaite CTC stats
- stats["loss_interctc_layer{}".format(layer_idx)] = (
- loss_ic.detach() if loss_ic is not None else None
- )
- stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
- loss_interctc = loss_interctc / len(intermediate_outs)
-
- # calculate whole encoder loss
- loss_ctc = (
- 1 - self.interctc_weight
- ) * loss_ctc + self.interctc_weight * loss_interctc
-
- # 2b. Attention decoder branch
- if self.ctc_weight != 1.0:
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att
- elif self.ctc_weight == 1.0:
- loss = loss_ctc
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
+ stats["pse_loss"] = pse_loss.detach()
+ stats["pit_loss"] = pit_loss.detach()
+ stats["attractor_loss"] = attractor_loss.detach()
+ stats["batch_size"] = batch_size
# Collect total loss stats
stats["loss"] = torch.clone(loss.detach())
@@ -182,21 +154,20 @@
def estimate_sequential(self,
speech: torch.Tensor,
- speech_lengths: torch.Tensor,
n_speakers: int = None,
shuffle: bool = True,
threshold: float = 0.5,
**kwargs):
- speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
+ speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
emb = self.forward_encoder(speech, speech_lengths)
if shuffle:
orders = [np.arange(e.shape[0]) for e in emb]
for order in orders:
np.random.shuffle(order)
- attractors, probs = self.eda.estimate(
+ attractors, probs = self.encoder_decoder_attractor.estimate(
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
else:
- attractors, probs = self.eda.estimate(emb)
+ attractors, probs = self.encoder_decoder_attractor.estimate(emb)
attractors_active = []
for p, att, e in zip(probs, attractors, emb):
if n_speakers and n_speakers >= 0:
--
Gitblit v1.9.1