| | |
| | | import torch.nn.functional as F |
| | | |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.losses.label_smoothing_loss import ( |
| | | LabelSmoothingLoss, NllLoss # noqa: H301 |
| | | ) |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, NllLoss # noqa: H301 |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.frontends.abs_frontend import AbsFrontend |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.transformer.add_sos_eos import add_sos_eos |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.metrics import ErrorCalculator |
| | | from funasr.models.transformer.utils.nets_utils import th_accuracy |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.models.base_model import FunASRModel |
| | | |
| | |
| | | """CTC-attention hybrid Encoder-Decoder model""" |
| | | |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | | max_spk_num: int, |
| | | token_list: Union[Tuple[str, ...], List[str]], |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | asr_encoder: AbsEncoder, |
| | | spk_encoder: torch.nn.Module, |
| | | decoder: AbsDecoder, |
| | | ctc: CTC, |
| | | spk_weight: float = 0.5, |
| | | ctc_weight: float = 0.5, |
| | | interctc_weight: float = 0.0, |
| | | ignore_id: int = -1, |
| | | lsm_weight: float = 0.0, |
| | | length_normalized_loss: bool = False, |
| | | report_cer: bool = True, |
| | | report_wer: bool = True, |
| | | sym_space: str = "<space>", |
| | | sym_blank: str = "<blank>", |
| | | extract_feats_in_collect_stats: bool = True, |
| | | self, |
| | | vocab_size: int, |
| | | max_spk_num: int, |
| | | token_list: Union[Tuple[str, ...], List[str]], |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | asr_encoder: AbsEncoder, |
| | | spk_encoder: torch.nn.Module, |
| | | decoder: AbsDecoder, |
| | | ctc: CTC, |
| | | spk_weight: float = 0.5, |
| | | ctc_weight: float = 0.5, |
| | | interctc_weight: float = 0.0, |
| | | ignore_id: int = -1, |
| | | lsm_weight: float = 0.0, |
| | | length_normalized_loss: bool = False, |
| | | report_cer: bool = True, |
| | | report_wer: bool = True, |
| | | sym_space: str = "<space>", |
| | | sym_blank: str = "<blank>", |
| | | extract_feats_in_collect_stats: bool = True, |
| | | ): |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | | assert 0.0 <= interctc_weight < 1.0, interctc_weight |
| | |
| | | self.sos = 1 |
| | | self.eos = 2 |
| | | self.vocab_size = vocab_size |
| | | self.max_spk_num=max_spk_num |
| | | self.max_spk_num = max_spk_num |
| | | self.ignore_id = ignore_id |
| | | self.spk_weight = spk_weight |
| | | self.ctc_weight = ctc_weight |
| | |
| | | ) |
| | | |
| | | self.error_calculator = None |
| | | |
| | | |
| | | # we set self.decoder = None in the CTC mode since |
| | | # self.decoder parameters were never used and PyTorch complained |
| | |
| | | self.extract_feats_in_collect_stats = extract_feats_in_collect_stats |
| | | |
| | | def forward( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | profile: torch.Tensor, |
| | | profile_lengths: torch.Tensor, |
| | | text_id: torch.Tensor, |
| | | text_id_lengths: torch.Tensor |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | profile: torch.Tensor, |
| | | profile_lengths: torch.Tensor, |
| | | text_id: torch.Tensor, |
| | | text_id_lengths: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Frontend + Encoder + Decoder + Calc loss |
| | | |
| | |
| | | assert text_lengths.dim() == 1, text_lengths.shape |
| | | # Check that batch_size is unified |
| | | assert ( |
| | | speech.shape[0] |
| | | == speech_lengths.shape[0] |
| | | == text.shape[0] |
| | | == text_lengths.shape[0] |
| | | speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0] |
| | | ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) |
| | | batch_size = speech.shape[0] |
| | | |
| | |
| | | asr_encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | |
| | | # Intermediate CTC (optional) |
| | | loss_interctc = 0.0 |
| | | if self.interctc_weight != 0.0 and intermediate_outs is not None: |
| | |
| | | loss_interctc = loss_interctc / len(intermediate_outs) |
| | | |
| | | # calculate whole encoder loss |
| | | loss_ctc = ( |
| | | 1 - self.interctc_weight |
| | | ) * loss_ctc + self.interctc_weight * loss_interctc |
| | | |
| | | loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc |
| | | |
| | | # 2b. Attention decoder branch |
| | | if self.ctc_weight != 1.0: |
| | | loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss( |
| | | asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths |
| | | asr_encoder_out, |
| | | spk_encoder_out, |
| | | encoder_out_lens, |
| | | text, |
| | | text_lengths, |
| | | profile, |
| | | profile_lengths, |
| | | text_id, |
| | | text_id_lengths, |
| | | ) |
| | | |
| | | # 3. CTC-Att loss definition |
| | |
| | | loss = loss_asr |
| | | else: |
| | | loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr |
| | | |
| | | |
| | | stats = dict( |
| | | loss=loss.detach(), |
| | |
| | | return loss, stats, weight |
| | | |
| | | def collect_feats( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | ) -> Dict[str, torch.Tensor]: |
| | | if self.extract_feats_in_collect_stats: |
| | | feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | |
| | | return {"feats": feats, "feats_lengths": feats_lengths} |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Frontend + Encoder. Note that this method is used by asr_inference.py |
| | | |
| | |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | if self.asr_encoder.interctc_use_conditioning: |
| | | encoder_out, encoder_out_lens, _ = self.asr_encoder( |
| | | feats, feats_lengths, ctc=self.ctc |
| | | ) |
| | | encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths, ctc=self.ctc) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths) |
| | | intermediate_outs = None |
| | |
| | | |
| | | encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0] |
| | | # import ipdb;ipdb.set_trace() |
| | | if encoder_out_spk_ori.size(1)!=encoder_out.size(1): |
| | | encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1) |
| | | if encoder_out_spk_ori.size(1) != encoder_out.size(1): |
| | | encoder_out_spk = F.interpolate( |
| | | encoder_out_spk_ori.transpose(-2, -1), size=(encoder_out.size(1)), mode="nearest" |
| | | ).transpose(-2, -1) |
| | | else: |
| | | encoder_out_spk=encoder_out_spk_ori |
| | | encoder_out_spk = encoder_out_spk_ori |
| | | |
| | | assert encoder_out.size(0) == speech.size(0), ( |
| | | encoder_out.size(), |
| | |
| | | return encoder_out, encoder_out_lens, encoder_out_spk |
| | | |
| | | def _extract_feats( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | assert speech_lengths.dim() == 1, speech_lengths.shape |
| | | |
| | |
| | | return feats, feats_lengths |
| | | |
| | | def nll( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ) -> torch.Tensor: |
| | | """Compute negative log likelihood(nll) from transformer-decoder |
| | | |
| | |
| | | return nll |
| | | |
| | | def batchify_nll( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | batch_size: int = 100, |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | batch_size: int = 100, |
| | | ): |
| | | """Compute negative log likelihood(nll) from transformer-decoder |
| | | |
| | |
| | | return nll |
| | | |
| | | def _calc_att_loss( |
| | | self, |
| | | asr_encoder_out: torch.Tensor, |
| | | spk_encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | profile: torch.Tensor, |
| | | profile_lens: torch.Tensor, |
| | | text_id: torch.Tensor, |
| | | text_id_lengths: torch.Tensor |
| | | self, |
| | | asr_encoder_out: torch.Tensor, |
| | | spk_encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | profile: torch.Tensor, |
| | | profile_lens: torch.Tensor, |
| | | text_id: torch.Tensor, |
| | | text_id_lengths: torch.Tensor, |
| | | ): |
| | | ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) |
| | | ys_in_lens = ys_pad_lens + 1 |
| | | |
| | | # 1. Forward decoder |
| | | decoder_out, weights_no_pad, _ = self.decoder( |
| | | asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens |
| | | asr_encoder_out, |
| | | spk_encoder_out, |
| | | encoder_out_lens, |
| | | ys_in_pad, |
| | | ys_in_lens, |
| | | profile, |
| | | profile_lens, |
| | | ) |
| | | |
| | | spk_num_no_pad=weights_no_pad.size(-1) |
| | | pad=(0,self.max_spk_num-spk_num_no_pad) |
| | | weights=F.pad(weights_no_pad, pad, mode='constant', value=0) |
| | | spk_num_no_pad = weights_no_pad.size(-1) |
| | | pad = (0, self.max_spk_num - spk_num_no_pad) |
| | | weights = F.pad(weights_no_pad, pad, mode="constant", value=0) |
| | | |
| | | # pre_id=weights.argmax(-1) |
| | | # pre_text=decoder_out.argmax(-1) |
| | |
| | | loss_att = self.criterion_att(decoder_out, ys_out_pad) |
| | | loss_spk = self.criterion_spk(torch.log(weights), text_id) |
| | | |
| | | acc_spk= th_accuracy( |
| | | acc_spk = th_accuracy( |
| | | weights.view(-1, self.max_spk_num), |
| | | text_id, |
| | | ignore_label=self.ignore_id, |
| | |
| | | return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att |
| | | |
| | | def _calc_ctc_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ): |
| | | # Calc CTC loss |
| | | loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) |