| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import logging |
| | | import torch |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.losses.label_smoothing_loss import ( |
| | | LabelSmoothingLoss, # noqa: H301 |
| | | ) |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.modules.add_sos_eos import add_sos_eos |
| | | from funasr.modules.e2e_asr_common import ErrorCalculator |
| | | from typing import Dict |
| | | from typing import Tuple |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn as nn |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.eend_ola.encoder import TransformerEncoder |
| | | from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor |
| | | from funasr.modules.eend_ola.utils.power import generate_mapping_dict |
| | | from funasr.modules.nets_utils import th_accuracy |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from typeguard import check_argument_types |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | | pass |
| | | else: |
| | | # Nothing to do if torch<1.6.0 |
| | | @contextmanager |
| | |
| | | self, |
| | | encoder: TransformerEncoder, |
| | | eda: EncoderDecoderAttractor, |
| | | n_units: int = 256, |
| | | max_n_speaker: int = 8, |
| | | attractor_loss_weight: float = 1.0, |
| | | mapping_dict=None, |
| | |
| | | if mapping_dict is None: |
| | | mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker) |
| | | self.mapping_dict = mapping_dict |
| | | # PostNet |
| | | self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True) |
| | | self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1) |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | def collect_feats( |
| | | self, |
| | | def estimate_sequential(self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | ) -> Dict[str, torch.Tensor]: |
| | | if self.extract_feats_in_collect_stats: |
| | | feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | n_speakers: int, |
| | | shuffle: bool, |
| | | threshold: float, |
| | | **kwargs): |
| | | speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)] |
| | | emb = self.forward_core(speech) # list, [(T1, C1), ..., (T1, C1)] |
| | | if shuffle: |
| | | orders = [np.arange(e.shape[0]) for e in emb] |
| | | for order in orders: |
| | | np.random.shuffle(order) |
| | | # e[order]: shuffle后的embeddings, list, [(T1, C1), ..., (T1, C1)] 每个sample的T维度已进行随机顺序交换 |
| | | # attractors, list, hts(论文里的as), [(max_n_speakers, n_units), ..., (max_n_speakers, n_units)] |
| | | # probs, list, [(max_n_speakers, ), ..., (max_n_speakers, ] |
| | | attractors, probs = self.eda.estimate( |
| | | [e[torch.from_numpy(order).to(torch.long).to(xs[0].device)] for e, order in zip(emb, orders)]) |
| | | else: |
| | | # Generate dummy stats if extract_feats_in_collect_stats is False |
| | | logging.warning( |
| | | "Generating dummy stats for feats and feats_lengths, " |
| | | "because encoder_conf.extract_feats_in_collect_stats is " |
| | | f"{self.extract_feats_in_collect_stats}" |
| | | ) |
| | | feats, feats_lengths = speech, speech_lengths |
| | | return {"feats": feats, "feats_lengths": feats_lengths} |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Frontend + Encoder. Note that this method is used by asr_inference.py |
| | | |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | """ |
| | | with autocast(False): |
| | | # 1. Extract feats |
| | | feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | |
| | | # 2. Data augmentation |
| | | if self.specaug is not None and self.training: |
| | | feats, feats_lengths = self.specaug(feats, feats_lengths) |
| | | |
| | | # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN |
| | | if self.normalize is not None: |
| | | feats, feats_lengths = self.normalize(feats, feats_lengths) |
| | | |
| | | # Pre-encoder, e.g. used for raw input data |
| | | if self.preencoder is not None: |
| | | feats, feats_lengths = self.preencoder(feats, feats_lengths) |
| | | |
| | | # 4. Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | if self.encoder.interctc_use_conditioning: |
| | | encoder_out, encoder_out_lens, _ = self.encoder( |
| | | feats, feats_lengths, ctc=self.ctc |
| | | ) |
| | | attractors, probs = self.eda.estimate(emb) |
| | | attractors_active = [] |
| | | for p, att, e in zip(probs, attractors, emb): |
| | | if n_speakers and n_speakers >= 0: # 根据指定说话人数, 选择对应数量的ys |
| | | # TODO:在测试有不同数量speaker数的数据集时,考虑改成根据sample来确定具体的speaker数,而不是直接指定 |
| | | # raise NotImplementedError |
| | | att = att[:n_speakers, ] |
| | | attractors_active.append(att) |
| | | elif threshold is not None: |
| | | silence = torch.nonzero(p < threshold)[0] # 找到第一个输出概率小于阈值的索引, 作为结束, 且值刚好等于说话人数 |
| | | n_spk = silence[0] if silence.size else None |
| | | att = att[:n_spk, ] |
| | | attractors_active.append(att) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | NotImplementedError('n_speakers or th has to be given.') |
| | | raw_n_speakers = [att.shape[0] for att in attractors_active] # [C1, C2, ..., CB] |
| | | attractors = [ |
| | | pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker] |
| | | for att in attractors_active] |
| | | ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)] |
| | | # ys_eda = [torch.sigmoid(y[:, :n_spk]) for y,n_spk in zip(ys, raw_n_speakers)] |
| | | logits = self.cal_postnet(ys, self.max_n_speaker) |
| | | ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in |
| | | zip(logits, raw_n_speakers)] |
| | | |
| | | # Post-encoder, e.g. NLU |
| | | if self.postencoder is not None: |
| | | encoder_out, encoder_out_lens = self.postencoder( |
| | | encoder_out, encoder_out_lens |
| | | ) |
| | | return ys, emb, attractors, raw_n_speakers |
| | | |
| | | assert encoder_out.size(0) == speech.size(0), ( |
| | | encoder_out.size(), |
| | | speech.size(0), |
| | | ) |
| | | assert encoder_out.size(1) <= encoder_out_lens.max(), ( |
| | | encoder_out.size(), |
| | | encoder_out_lens.max(), |
| | | ) |
| | | |
| | | if intermediate_outs is not None: |
| | | return (encoder_out, intermediate_outs), encoder_out_lens |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
| | | def _extract_feats( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | assert speech_lengths.dim() == 1, speech_lengths.shape |
| | | |
| | | # for data-parallel |
| | | speech = speech[:, : speech_lengths.max()] |
| | | |
| | | if self.frontend is not None: |
| | | # Frontend |
| | | # e.g. STFT and Feature extract |
| | | # data_loader may send time-domain signal in this case |
| | | # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) |
| | | feats, feats_lengths = self.frontend(speech, speech_lengths) |
| | | def recover_y_from_powerlabel(self, logit, n_speaker): |
| | | pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) # (T, ) |
| | | oov_index = torch.where(pred == self.mapping_dict['oov'])[0] |
| | | for i in oov_index: |
| | | if i > 0: |
| | | pred[i] = pred[i - 1] |
| | | else: |
| | | # No frontend and no feature extract |
| | | feats, feats_lengths = speech, speech_lengths |
| | | return feats, feats_lengths |
| | | |
| | | def nll( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ) -> torch.Tensor: |
| | | """Compute negative log likelihood(nll) from transformer-decoder |
| | | |
| | | Normally, this function is called in batchify_nll. |
| | | |
| | | Args: |
| | | encoder_out: (Batch, Length, Dim) |
| | | encoder_out_lens: (Batch,) |
| | | ys_pad: (Batch, Length) |
| | | ys_pad_lens: (Batch,) |
| | | """ |
| | | ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) |
| | | ys_in_lens = ys_pad_lens + 1 |
| | | |
| | | # 1. Forward decoder |
| | | decoder_out, _ = self.decoder( |
| | | encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens |
| | | ) # [batch, seqlen, dim] |
| | | batch_size = decoder_out.size(0) |
| | | decoder_num_class = decoder_out.size(2) |
| | | # nll: negative log-likelihood |
| | | nll = torch.nn.functional.cross_entropy( |
| | | decoder_out.view(-1, decoder_num_class), |
| | | ys_out_pad.view(-1), |
| | | ignore_index=self.ignore_id, |
| | | reduction="none", |
| | | ) |
| | | nll = nll.view(batch_size, -1) |
| | | nll = nll.sum(dim=1) |
| | | assert nll.size(0) == batch_size |
| | | return nll |
| | | |
| | | def batchify_nll( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | batch_size: int = 100, |
| | | ): |
| | | """Compute negative log likelihood(nll) from transformer-decoder |
| | | |
| | | To avoid OOM, this fuction seperate the input into batches. |
| | | Then call nll for each batch and combine and return results. |
| | | Args: |
| | | encoder_out: (Batch, Length, Dim) |
| | | encoder_out_lens: (Batch,) |
| | | ys_pad: (Batch, Length) |
| | | ys_pad_lens: (Batch,) |
| | | batch_size: int, samples each batch contain when computing nll, |
| | | you may change this to avoid OOM or increase |
| | | GPU memory usage |
| | | """ |
| | | total_num = encoder_out.size(0) |
| | | if total_num <= batch_size: |
| | | nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) |
| | | else: |
| | | nll = [] |
| | | start_idx = 0 |
| | | while True: |
| | | end_idx = min(start_idx + batch_size, total_num) |
| | | batch_encoder_out = encoder_out[start_idx:end_idx, :, :] |
| | | batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx] |
| | | batch_ys_pad = ys_pad[start_idx:end_idx, :] |
| | | batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx] |
| | | batch_nll = self.nll( |
| | | batch_encoder_out, |
| | | batch_encoder_out_lens, |
| | | batch_ys_pad, |
| | | batch_ys_pad_lens, |
| | | ) |
| | | nll.append(batch_nll) |
| | | start_idx = end_idx |
| | | if start_idx == total_num: |
| | | break |
| | | nll = torch.cat(nll) |
| | | assert nll.size(0) == total_num |
| | | return nll |
| | | |
| | | def _calc_att_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ): |
| | | ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) |
| | | ys_in_lens = ys_pad_lens + 1 |
| | | |
| | | # 1. Forward decoder |
| | | decoder_out, _ = self.decoder( |
| | | encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens |
| | | ) |
| | | |
| | | # 2. Compute attention loss |
| | | loss_att = self.criterion_att(decoder_out, ys_out_pad) |
| | | acc_att = th_accuracy( |
| | | decoder_out.view(-1, self.vocab_size), |
| | | ys_out_pad, |
| | | ignore_label=self.ignore_id, |
| | | ) |
| | | |
| | | # Compute cer/wer using attention-decoder |
| | | if self.training or self.error_calculator is None: |
| | | cer_att, wer_att = None, None |
| | | else: |
| | | ys_hat = decoder_out.argmax(dim=-1) |
| | | cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) |
| | | |
| | | return loss_att, acc_att, cer_att, wer_att |
| | | |
| | | def _calc_ctc_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ): |
| | | # Calc CTC loss |
| | | loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) |
| | | |
| | | # Calc CER using CTC |
| | | cer_ctc = None |
| | | if not self.training and self.error_calculator is not None: |
| | | ys_hat = self.ctc.argmax(encoder_out).data |
| | | cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) |
| | | return loss_ctc, cer_ctc |
| | | pred[i] = 0 |
| | | pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred] |
| | | # print(pred) |
| | | decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred] |
| | | decisions = torch.from_numpy( |
| | | np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to( |
| | | torch.float32) |
| | | decisions = decisions[:, :n_speaker] |
| | | return decisions |