| | |
| | | import torch |
| | | import random |
| | | import numpy as np |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.losses.label_smoothing_loss import ( |
| | |
| | | postencoder: Optional[AbsPostEncoder] = None, |
| | | use_1st_decoder_loss: bool = False, |
| | | ): |
| | | assert check_argument_types() |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | | assert 0.0 <= interctc_weight < 1.0, interctc_weight |
| | | |
| | |
| | | loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight |
| | | |
| | | if self.use_1st_decoder_loss and pre_loss_att is not None: |
| | | loss = loss + pre_loss_att |
| | | loss = loss + (1 - self.ctc_weight) * pre_loss_att |
| | | |
| | | # Collect Attn branch stats |
| | | stats["loss_att"] = loss_att.detach() if loss_att is not None else None |
| | |
| | | postencoder: Optional[AbsPostEncoder] = None, |
| | | use_1st_decoder_loss: bool = False, |
| | | ): |
| | | assert check_argument_types() |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | | assert 0.0 <= interctc_weight < 1.0, interctc_weight |
| | | |
| | |
| | | mask_chunk_predictor=mask_chunk_predictor, |
| | | target_label_length=None, |
| | | ) |
| | | predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas[:, :-1], |
| | | encoder_out_lens) |
| | | predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas, |
| | | encoder_out_lens+1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens) |
| | | |
| | | scama_mask = None |
| | | if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk': |
| | |
| | | preencoder: Optional[AbsPreEncoder] = None, |
| | | postencoder: Optional[AbsPostEncoder] = None, |
| | | ): |
| | | assert check_argument_types() |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | | assert 0.0 <= interctc_weight < 1.0, interctc_weight |
| | | |
| | |
| | | preencoder: Optional[AbsPreEncoder] = None, |
| | | postencoder: Optional[AbsPostEncoder] = None, |
| | | ): |
| | | assert check_argument_types() |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | | assert 0.0 <= interctc_weight < 1.0, interctc_weight |
| | | |
| | |
| | | preencoder: Optional[AbsPreEncoder] = None, |
| | | postencoder: Optional[AbsPostEncoder] = None, |
| | | ): |
| | | assert check_argument_types() |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | | assert 0.0 <= interctc_weight < 1.0, interctc_weight |
| | | |
| | |
| | | |
| | | return loss_att, acc_att, cer_att, wer_att, loss_pre |
| | | |
| | | def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None): |
| | | def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0): |
| | | if hw_list is None: |
| | | # default hotword list |
| | | hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)] # empty hotword list |