#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

import time
import torch
import logging
from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional

from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a (Batch, maxlen) mask that is 1 inside each sequence and 0 after it.

    This module calls ``sequence_mask`` but never imported it, so a small
    torch-only implementation is provided here.
    NOTE(review): the "valid frames = 1" convention is assumed from how the
    mask is handed to the predictors -- confirm against the predictor code.

    Args:
        lengths: (Batch,) integer tensor of sequence lengths.
        maxlen: mask width; defaults to ``lengths.max()``.
        dtype: dtype of the returned mask.
        device: optional device the mask is moved to.
    """
    if maxlen is None:
        maxlen = int(lengths.max())
    positions = torch.arange(maxlen, device=lengths.device)
    mask = (positions.unsqueeze(0) < lengths.unsqueeze(-1)).to(dtype)
    return mask if device is None else mask.to(device)


@tables.register("model_classes", "UniASR")
class UniASR(torch.nn.Module):
    """Two-pass ASR model: encoder/decoder 1 followed by encoder/decoder 2.

    The first pass encodes the input features; the second pass re-encodes the
    first-pass output (through ``stride_conv``). Each pass has its own
    CTC / attention-decoder / predictor losses, mixed by ``loss_weight_model1``.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    """

    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        predictor: str = None,
        predictor_conf: Optional[Dict] = None,
        ctc_weight: float = 0.5,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        share_embedding: bool = False,
        use_1st_decoder_loss: bool = False,
        encoder1_encoder2_joint_training: bool = True,
        **kwargs,
    ):
        """Store the sub-modules and loss weights of both passes.

        ``specaug``, ``normalize``, ``encoder``, ``decoder``, ``ctc`` and
        ``predictor`` are stored as given and later called as modules.
        NOTE(review): the annotations say ``str``; nothing here resolves a
        name into a module, so callers must pass ready-built objects.

        The following options are not named parameters and are read from
        ``**kwargs`` (previously they were referenced without being defined,
        which raised ``NameError``). NOTE(review): the defaults below were
        chosen so that an omitted optional component is disabled -- confirm
        them against the training configuration:

        ``interctc_weight`` / ``interctc_weight2`` (0.0), ``token_list``
        (None), ``frontend`` / ``preencoder`` / ``postencoder`` (None),
        ``error_calculator`` (None), ``extract_feats_in_collect_stats``
        (True), ``decoder_attention_chunk_type`` /
        ``decoder_attention_chunk_type2`` ("chunk"), ``encoder2`` /
        ``decoder2`` / ``ctc2`` / ``predictor2`` / ``stride_conv`` (None),
        ``ctc_weight2`` (0.0), ``predictor_weight2`` (0.0),
        ``loss_weight_model1`` (0.5), ``enable_maas_finetune`` (False),
        ``freeze_encoder2`` (False).
        """
        opt = kwargs.get
        interctc_weight = opt("interctc_weight", 0.0)
        interctc_weight2 = opt("interctc_weight2", 0.0)
        token_list = opt("token_list")
        ctc_weight2 = opt("ctc_weight2", 0.0)

        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # NOTE(review): the blank_id / sos / eos arguments are ignored and the
        # ids are hard-coded, exactly as before -- confirm this is intended.
        self.blank_id = 0
        self.sos = 1
        self.eos = 2
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.token_list = token_list.copy() if token_list is not None else None

        self.frontend = opt("frontend")
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = opt("preencoder")
        self.postencoder = opt("postencoder")
        self.encoder = encoder

        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        # ErrorCalculator is not imported by this module, so it cannot be
        # built here; a ready-made calculator may be supplied through kwargs.
        self.error_calculator = opt("error_calculator")

        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = opt("extract_feats_in_collect_stats", True)
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.step_cur = 0

        # First-pass chunk-attention settings. The mask builder is only
        # imported when the encoder actually works on overlapping chunks.
        self.decoder_attention_chunk_type = opt("decoder_attention_chunk_type", "chunk")
        self.build_scama_mask_for_cross_attention_decoder_fn = None
        if self.encoder.overlap_chunk_cls is not None:
            from funasr.models.scama.chunk_utilis import (
                build_scama_mask_for_cross_attention_decoder,
            )

            self.build_scama_mask_for_cross_attention_decoder_fn = (
                build_scama_mask_for_cross_attention_decoder
            )

        # Second pass.
        self.encoder2 = opt("encoder2")
        self.decoder2 = opt("decoder2")
        self.ctc_weight2 = ctc_weight2
        if ctc_weight2 == 0.0:
            self.ctc2 = None
        else:
            self.ctc2 = opt("ctc2")
        self.interctc_weight2 = interctc_weight2
        self.predictor2 = opt("predictor2")
        self.predictor_weight2 = opt("predictor_weight2", 0.0)
        self.decoder_attention_chunk_type2 = opt("decoder_attention_chunk_type2", "chunk")
        self.stride_conv = opt("stride_conv")
        self.loss_weight_model1 = opt("loss_weight_model1", 0.5)
        self.build_scama_mask_for_cross_attention_decoder_fn2 = None
        if self.encoder2.overlap_chunk_cls is not None:
            from funasr.models.scama.chunk_utilis import (
                build_scama_mask_for_cross_attention_decoder,
            )

            self.build_scama_mask_for_cross_attention_decoder_fn2 = (
                build_scama_mask_for_cross_attention_decoder
            )

        self.enable_maas_finetune = opt("enable_maas_finetune", False)
        self.freeze_encoder2 = opt("freeze_encoder2", False)
        self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
        self.length_normalized_loss = length_normalized_loss

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        decoding_ind: int = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            decoding_ind: forwarded to ``overlap_chunk_cls.random_choice``.

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]

        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)

        # 1. Encoder (no gradients for the first pass when fine-tuning pass 2 only)
        if self.enable_maas_finetune:
            with torch.no_grad():
                speech_raw, encoder_out, encoder_out_lens = self.encode(
                    speech, speech_lengths, ind=ind
                )
        else:
            speech_raw, encoder_out, encoder_out_lens = self.encode(
                speech, speech_lengths, ind=ind
            )
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        stats = dict()
        loss1, loss2 = 0.0, 0.0

        if self.loss_weight_model1 > 0.0:
            ## model1
            if self.enable_maas_finetune:
                with torch.no_grad():
                    loss1 = self._calc_branch_loss(
                        1, encoder_out, encoder_out_lens, intermediate_outs,
                        text, text_lengths, stats,
                    )
            else:
                loss1 = self._calc_branch_loss(
                    1, encoder_out, encoder_out_lens, intermediate_outs,
                    text, text_lengths, stats,
                )

        if self.loss_weight_model1 < 1.0:
            ## model2
            # encoder2
            if self.freeze_encoder2:
                with torch.no_grad():
                    encoder_out, encoder_out_lens = self.encode2(
                        encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind
                    )
            else:
                encoder_out, encoder_out_lens = self.encode2(
                    encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind
                )

            intermediate_outs = None
            if isinstance(encoder_out, tuple):
                intermediate_outs = encoder_out[1]
                encoder_out = encoder_out[0]

            loss2 = self._calc_branch_loss(
                2, encoder_out, encoder_out_lens, intermediate_outs,
                text, text_lengths, stats,
            )

        loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)

        def as_stat(value):
            # A disabled pass leaves its loss as the float 0.0, which has no
            # ``detach``; turn it into a tensor matching the total loss.
            if torch.is_tensor(value):
                return value.detach().clone()
            return loss.new_tensor(value)

        stats["loss1"] = as_stat(loss1)
        stats["loss2"] = as_stat(loss2)
        stats["loss"] = as_stat(loss)

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _branch_modules(self, branch: int):
        """Return (encoder, predictor, decoder, chunk_type, mask_fn) of pass 1 or 2."""
        if branch == 1:
            return (
                self.encoder,
                self.predictor,
                self.decoder,
                self.decoder_attention_chunk_type,
                self.build_scama_mask_for_cross_attention_decoder_fn,
            )
        return (
            self.encoder2,
            self.predictor2,
            self.decoder2,
            self.decoder_attention_chunk_type2,
            self.build_scama_mask_for_cross_attention_decoder_fn2,
        )

    def _calc_branch_loss(
        self,
        branch: int,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        intermediate_outs,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        stats: dict,
    ):
        """Compute the CTC / inter-CTC / attention / predictor loss of one pass.

        Stats are written into ``stats`` in place; keys of the second pass
        carry a trailing ``"2"``. Returns the combined loss of the pass.
        """
        if branch == 1:
            chunk_cls = self.encoder.overlap_chunk_cls
            ctc_weight = self.ctc_weight
            interctc_weight = self.interctc_weight
            predictor_weight = self.predictor_weight
            calc_ctc = self._calc_ctc_loss
            calc_att = self._calc_att_predictor_loss
            suffix = ""
        else:
            chunk_cls = self.encoder2.overlap_chunk_cls
            ctc_weight = self.ctc_weight2
            interctc_weight = self.interctc_weight2
            predictor_weight = self.predictor_weight2
            calc_ctc = self._calc_ctc_loss2
            calc_att = self._calc_att_predictor_loss2
            suffix = "2"

        # Start every pass from a clean slate so nothing leaks over from the
        # other pass.
        loss_ctc, loss_att, loss_pre = None, None, None
        acc_att, cer_att, wer_att = None, None, None

        def strip_chunks(hidden):
            # Without chunking the encoder output is used as it is
            # (previously a stale or undefined tensor was used here).
            if chunk_cls is None:
                return hidden, encoder_out_lens
            return chunk_cls.remove_chunk(hidden, encoder_out_lens, chunk_outs=None)

        # 1. CTC branch
        if ctc_weight != 0.0:
            ctc_in, ctc_in_lens = strip_chunks(encoder_out)
            loss_ctc, cer_ctc = calc_ctc(ctc_in, ctc_in_lens, text, text_lengths)

            # Collect CTC branch stats
            stats[f"loss_ctc{suffix}"] = loss_ctc.detach() if loss_ctc is not None else None
            stats[f"cer_ctc{suffix}"] = cer_ctc

        # Intermediate CTC (optional)
        if interctc_weight != 0.0 and intermediate_outs is not None:
            loss_interctc = 0.0
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                ctc_in, ctc_in_lens = strip_chunks(intermediate_out)
                loss_ic, cer_ic = calc_ctc(ctc_in, ctc_in_lens, text, text_lengths)
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats[f"loss_interctc_layer{layer_idx}{suffix}"] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats[f"cer_interctc_layer{layer_idx}{suffix}"] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre = calc_att(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if ctc_weight == 0.0:
            loss = loss_att + loss_pre * predictor_weight
        elif ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = (
                ctc_weight * loss_ctc
                + (1 - ctc_weight) * loss_att
                + loss_pre * predictor_weight
            )

        # Collect Attn branch stats
        stats[f"loss_att{suffix}"] = loss_att.detach() if loss_att is not None else None
        stats[f"acc{suffix}"] = acc_att
        stats[f"cer{suffix}"] = cer_att
        stats[f"wer{suffix}"] = wer_att
        stats[f"loss_pre{suffix}"] = loss_pre.detach().cpu() if loss_pre is not None else None
        return loss

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Return features and their lengths for statistics collection.

        When ``extract_feats_in_collect_stats`` is false the inputs are
        returned unchanged and a warning is logged.
        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            ind: chunk configuration index forwarded to the encoder.

        Returns:
            ``(speech_raw, encoder_out, encoder_out_lens)`` where
            ``speech_raw`` is a copy of the features before the pre-encoder.
            When the encoder yields intermediate outputs, ``encoder_out`` is
            the tuple ``(encoder_out, intermediate_outs)``. The result always
            has three elements, which is what ``forward`` unpacks (the
            intermediate-output case used to drop ``speech_raw``).
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Kept for the second pass, which concatenates it with encoder 1 output.
        speech_raw = feats.clone()

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc, ind=ind
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(encoder_out, encoder_out_lens)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if intermediate_outs is not None:
            return speech_raw, (encoder_out, intermediate_outs), encoder_out_lens

        return speech_raw, encoder_out, encoder_out_lens

    def encode2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Second-pass encoder.

        The chunking of the first-pass output is removed, the result is
        concatenated with ``speech`` along the last dimension, passed through
        ``stride_conv`` (when set) and then through ``encoder2``.

        Args:
            encoder_out: first-pass encoder output.
            encoder_out_lens: lengths of ``encoder_out``.
            speech: features returned as ``speech_raw`` by ``encode``.
            speech_lengths: (Batch, )
            ind: chunk configuration index forwarded to ``encoder2``.

        Returns:
            ``(encoder_out, encoder_out_lens)``; ``encoder_out`` is a tuple
            ``(encoder_out, intermediate_outs)`` when ``encoder2`` yields
            intermediate outputs.
        """
        encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
            encoder_out,
            encoder_out_lens,
            chunk_outs=None,
        )
        # residual_input
        encoder_out = torch.cat((speech, encoder_out_rm), dim=-1)
        encoder_out_lens = encoder_out_lens_rm
        # NOTE(review): when stride_conv is None the concatenated tensor is
        # not used and encoder2 receives `speech` itself -- confirm intended.
        if self.stride_conv is not None:
            speech, speech_lengths = self.stride_conv(encoder_out, encoder_out_lens)
        if not self.encoder1_encoder2_joint_training:
            # Cut the graph so encoder 1 gets no gradient from pass 2.
            speech = speech.detach()
            speech_lengths = speech_lengths.detach()

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder2.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder2(
                speech, speech_lengths, ctc=self.ctc2, ind=ind
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the frontend on ``speech``, or pass it through when there is none."""
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)

        Returns:
            (Batch,) tensor with the summed per-token nll of each utterance.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood(nll) from transformer-decoder

        To avoid OOM, this function separates the input into batches,
        calls nll for each batch, then combines and returns the results.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: int, samples each batch contain when computing nll,
                        you may change this to avoid OOM or increase
                        GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            pieces = []
            for start_idx in range(0, total_num, batch_size):
                end_idx = min(start_idx + batch_size, total_num)
                pieces.append(
                    self.nll(
                        encoder_out[start_idx:end_idx, :, :],
                        encoder_out_lens[start_idx:end_idx],
                        ys_pad[start_idx:end_idx, :],
                        ys_pad_lens[start_idx:end_idx],
                    )
                )
            nll = torch.cat(pieces)
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder loss of pass 1 without the predictor.

        Returns ``(loss_att, acc_att, cer_att, wer_att)``; cer/wer are None
        while training or when no error calculator is set.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _run_predictor(
        self,
        branch: int,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_out_pad,
        ys_in_lens,
    ):
        """Run the predictor of one pass and build its decoder attention mask.

        Shared by the attention-loss and ``calc_predictor_mask`` methods of
        both passes.

        Returns:
            ``(encoder_out, encoder_out_lens, pre_acoustic_embeds,
            pre_token_length, predictor_alignments, predictor_alignments_len,
            scama_mask)``. ``encoder_out`` has the chunk shift mask applied
            and, when chunking is on but the attention type is not
            ``'chunk'``, has its chunks removed.
        """
        encoder, predictor, _, chunk_type, build_mask = self._branch_modules(branch)
        chunk_cls = encoder.overlap_chunk_cls
        device = encoder_out.device
        batch = encoder_out.size(0)

        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=device,
        )[:, None, :]
        mask_chunk_predictor = None
        if chunk_cls is not None:
            mask_chunk_predictor = chunk_cls.get_mask_chunk_predictor(
                None, device=device, batch_size=batch
            )
            mask_shfit_chunk = chunk_cls.get_mask_shfit_chunk(
                None, device=device, batch_size=batch
            )
            encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = predictor(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        predictor_alignments, predictor_alignments_len = predictor.gen_frame_alignments(
            pre_alphas, encoder_out_lens
        )

        scama_mask = None
        if chunk_cls is not None and chunk_type == 'chunk':
            encoder_chunk_size = chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = chunk_cls.get_mask_shift_att_chunk_decoder(
                None, device=device, batch_size=batch
            )
            scama_mask = build_mask(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_in_lens,
                is_training=self.training,
            )
        elif chunk_cls is not None:
            encoder_out, encoder_out_lens = chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

        return (
            encoder_out,
            encoder_out_lens,
            pre_acoustic_embeds,
            pre_token_length,
            predictor_alignments,
            predictor_alignments_len,
            scama_mask,
        )

    def _att_predictor_loss(
        self,
        branch: int,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder and predictor loss of one pass.

        Returns ``(loss_att, acc_att, cer_att, wer_att, loss_pre)``.
        """
        decoder = self._branch_modules(branch)[2]
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        (
            encoder_out,
            encoder_out_lens,
            pre_acoustic_embeds,
            pre_token_length,
            _,
            _,
            scama_mask,
        ) = self._run_predictor(branch, encoder_out, encoder_out_lens, ys_out_pad, ys_in_lens)

        # 1. Forward decoder
        decoder_out, _ = decoder(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            chunk_mask=scama_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # predictor loss
        loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention + predictor loss of pass 1 (encoder / predictor / decoder)."""
        return self._att_predictor_loss(1, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

    def _calc_att_predictor_loss2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention + predictor loss of pass 2 (encoder2 / predictor2 / decoder2)."""
        return self._att_predictor_loss(2, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

    def calc_predictor_mask(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        """Run predictor 1 without targets and build its attention mask.

        ``ys_pad`` and ``ys_pad_lens`` are accepted but not used.

        Returns ``(pre_acoustic_embeds, pre_token_length,
        predictor_alignments, predictor_alignments_len, scama_mask)``.
        """
        return self._run_predictor(1, encoder_out, encoder_out_lens, None, None)[2:]

    def calc_predictor_mask2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        """Run predictor 2 without targets and build its attention mask.

        ``ys_pad`` and ``ys_pad_lens`` are accepted but not used.

        Returns ``(pre_acoustic_embeds, pre_token_length,
        predictor_alignments, predictor_alignments_len, scama_mask)``.
        """
        return self._run_predictor(2, encoder_out, encoder_out_lens, None, None)[2:]

    def _ctc_loss(
        self,
        ctc,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """CTC loss with the given CTC module; CER only outside training."""
        # Calc CTC loss
        loss_ctc = ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Return ``(loss_ctc, cer_ctc)`` of pass 1 using ``self.ctc``."""
        return self._ctc_loss(self.ctc, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

    def _calc_ctc_loss2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Return ``(loss_ctc, cer_ctc)`` of pass 2 using ``self.ctc2``."""
        return self._ctc_loss(self.ctc2, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)