From 12dd848db2cfd0e2ae6f32cfb1a5aecdf0f77365 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 16 May 2023 11:16:31 +0800
Subject: [PATCH] Refactor EEND-OLA diarization model: adopt FunASRModel base and WavFrontendMel23 frontend
---
funasr/models/e2e_diar_eend_ola.py | 95 ++++++++++++++++++++++++++++++++---------------
1 file changed, 65 insertions(+), 30 deletions(-)
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 5c1c9ce..da7c674 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -11,11 +11,12 @@
import torch.nn as nn
from typeguard import check_argument_types
-from funasr.modules.eend_ola.encoder import TransformerEncoder
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
pass
@@ -26,13 +27,21 @@
yield
-class DiarEENDOLAModel(AbsESPnetModel):
- """CTC-attention hybrid Encoder-Decoder model"""
+def pad_attractor(att, max_n_speakers):
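+ # Zero-pad the attractor matrix (C, D) along the speaker axis to (max_n_speakers, D); a no-op when C >= max_n_speakers.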
+ C, D = att.shape
+ if C < max_n_speakers:
+ att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
+ return att
+
+
+class DiarEENDOLAModel(FunASRModel):
+ """EEND-OLA diarization model"""
def __init__(
self,
- encoder: TransformerEncoder,
- eda: EncoderDecoderAttractor,
+ frontend: WavFrontendMel23,
+ encoder: EENDOLATransformerEncoder,
+ encoder_decoder_attractor: EncoderDecoderAttractor,
n_units: int = 256,
max_n_speaker: int = 8,
attractor_loss_weight: float = 1.0,
@@ -42,16 +51,37 @@
assert check_argument_types()
super().__init__()
- self.encoder = encoder
- self.eda = eda
+ self.frontend = frontend
+ self.enc = encoder
+ self.eda = encoder_decoder_attractor
self.attractor_loss_weight = attractor_loss_weight
self.max_n_speaker = max_n_speaker
if mapping_dict is None:
mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
self.mapping_dict = mapping_dict
# PostNet
- self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
+ self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
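+ # Linear head over PostNet states: mapping_dict['oov'] + 1 classes, i.e. every power-set label plus an out-of-vocabulary (OOV) class.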
self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
+
+ def forward_encoder(self, xs, ilens):
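+ # Pad variable-length features into a batch, encode with the Transformer, then split back into per-utterance embeddings trimmed to their true lengths.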
+ xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
+ pad_shape = xs.shape
+ xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
+ xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
+ emb = self.enc(xs, xs_mask)
+ emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
+ emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
+ return emb
+
+ def forward_post_net(self, logits, ilens):
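+ # Pack the padded per-frame logits, run the single-layer LSTM PostNet, then unpad and project each utterance onto the power-set labels.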
+ maxlen = torch.max(ilens).to(torch.int).item()
+ logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
+ logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+ outputs, (_, _) = self.postnet(logits)
+ outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
+ outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
+ outputs = [self.output_layer(output) for output in outputs]
+ return outputs
def forward(
self,
@@ -61,7 +91,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -82,7 +111,7 @@
text = text[:, : text_lengths.max()]
# 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
@@ -156,62 +185,68 @@
def estimate_sequential(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
- n_speakers: int,
- shuffle: bool,
- threshold: float,
+ n_speakers: int = None,
+ shuffle: bool = True,
+ threshold: float = 0.5,
**kwargs):
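+ # Trim padding from each utterance; encoding and attractor estimation operate on the true lengths.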
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
- emb = self.forward_core(speech) # list, [(T1, C1), ..., (T1, C1)]
+ emb = self.forward_encoder(speech, speech_lengths)
if shuffle:
orders = [np.arange(e.shape[0]) for e in emb]
for order in orders:
np.random.shuffle(order)
- # e[order]: the embeddings after shuffling, list, [(T1, C1), ..., (T1, C1)]; the time axis of each sample has been randomly permuted
- # attractors, list, hts (the a_s in the paper), [(max_n_speakers, n_units), ..., (max_n_speakers, n_units)]
- # probs, list, [(max_n_speakers, ), ..., (max_n_speakers, )]
attractors, probs = self.eda.estimate(
- [e[torch.from_numpy(order).to(torch.long).to(xs[0].device)] for e, order in zip(emb, orders)])
+ [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
else:
attractors, probs = self.eda.estimate(emb)
attractors_active = []
for p, att, e in zip(probs, attractors, emb):
- if n_speakers and n_speakers >= 0: # select the corresponding number of ys according to the specified number of speakers
- # TODO: when evaluating on datasets with varying speaker counts, consider deriving the speaker count per sample instead of specifying it directly
- # raise NotImplementedError
+ if n_speakers and n_speakers >= 0:
att = att[:n_speakers, ]
attractors_active.append(att)
elif threshold is not None:
- silence = torch.nonzero(p < threshold)[0] # index of the first probability below the threshold; it marks the cutoff and equals the number of speakers
- n_spk = silence[0] if silence.size else None
+ silence = torch.nonzero(p < threshold, as_tuple=True)[0]
+ n_spk = silence[0] if silence.numel() else None
att = att[:n_spk, ]
attractors_active.append(att)
else:
- NotImplementedError('n_speakers or th has to be given.')
- raw_n_speakers = [att.shape[0] for att in attractors_active] # [C1, C2, ..., CB]
+ raise NotImplementedError('n_speakers or threshold has to be given.')
+ raw_n_speakers = [att.shape[0] for att in attractors_active]
attractors = [
pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
for att in attractors_active]
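+ # Frame-wise dot products between embeddings and attractors give per-speaker activity logits.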
ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
- # ys_eda = [torch.sigmoid(y[:, :n_spk]) for y,n_spk in zip(ys, raw_n_speakers)]
- logits = self.cal_postnet(ys, self.max_n_speaker)
+ logits = self.forward_post_net(ys, speech_lengths)
ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
zip(logits, raw_n_speakers)]
return ys, emb, attractors, raw_n_speakers
def recover_y_from_powerlabel(self, logit, n_speaker):
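+ # Convert per-frame power-set predictions back into binary per-speaker activity decisions.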
- pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) # (T, )
+ pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
for i in oov_index:
if i > 0:
pred[i] = pred[i - 1]
else:
pred[i] = 0
- pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
- # print(pred)
+ pred = [self.inv_mapping_func(i) for i in pred]
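+ # bin() turns each power-set integer into a bit string; zfill pads it to max_n_speaker bits and [::-1] puts speaker 0 at index 0.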
decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
decisions = torch.from_numpy(
np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
torch.float32)
decisions = decisions[:, :n_speaker]
return decisions
+
+ def inv_mapping_func(self, label):
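+ # Map a power-set class label back to its integer bit pattern via label2dec; -1 marks an unknown label.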
+ if not isinstance(label, int):
+ label = int(label)
+ if label in self.mapping_dict['label2dec']:
+ num = self.mapping_dict['label2dec'][label]
+ else:
+ num = -1
+ return num
+
+ def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
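+ # Feature collection is not used by EEND-OLA; return an empty dict to satisfy the FunASRModel interface.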
+ return {}
\ No newline at end of file
--
Gitblit v1.9.1