| | |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | import pdb |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | |
| | | crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0) |
| | | bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0) |
| | | |
| | | |
| | | pdb.set_trace() |
| | | if bias_encoder_type == 'lstm': |
| | | self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate) |
| | | self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim) |
| | |
| | | if self.crit_attn_weight > 0: |
| | | self.attn_loss = torch.nn.L1Loss() |
| | | self.crit_attn_smooth = crit_attn_smooth |
| | | pdb.set_trace() |
| | | |
| | | |
| | | def forward( |
| | |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | pdb.set_trace() |
| | | batch_size = speech.shape[0] |
| | | |
| | | hotword_pad = kwargs.get("hotword_pad") |
| | | hotword_lengths = kwargs.get("hotword_lengths") |
| | | dha_pad = kwargs.get("dha_pad") |
| | | |
| | | pdb.set_trace() |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | |
| | | pdb.set_trace() |
| | | loss_ctc, cer_ctc = None, None |
| | | |
| | | stats = dict() |
| | |
| | | stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None |
| | | stats["cer_ctc"] = cer_ctc |
| | | |
| | | |
| | | pdb.set_trace() |
| | | # 2b. Attention decoder branch |
| | | loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths |
| | | ) |
| | | |
| | | pdb.set_trace() |
| | | # 3. CTC-Att loss definition |
| | | if self.ctc_weight == 0.0: |
| | | loss = loss_att + loss_pre * self.predictor_weight |
| | |
| | | ): |
| | | encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( |
| | | encoder_out.device) |
| | | pdb.set_trace() |
| | | if self.predictor_bias == 1: |
| | | _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) |
| | | ys_pad_lens = ys_pad_lens + self.predictor_bias |
| | | pdb.set_trace() |
| | | pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask, |
| | | ignore_id=self.ignore_id) |
| | | |
| | | pdb.set_trace() |
| | | # -1. bias encoder |
| | | if self.use_decoder_embedding: |
| | | hw_embed = self.decoder.embed(hotword_pad) |
| | | else: |
| | | hw_embed = self.bias_embed(hotword_pad) |
| | | pdb.set_trace() |
| | | hw_embed, (_, _) = self.bias_encoder(hw_embed) |
| | | pdb.set_trace() |
| | | _ind = np.arange(0, hotword_pad.shape[0]).tolist() |
| | | selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]] |
| | | contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device) |
| | | |
| | | pdb.set_trace() |
| | | # 0. sampler |
| | | decoder_out_1st = None |
| | | if self.sampling_ratio > 0.0: |
| | |
| | | if self.step_cur < 2: |
| | | logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) |
| | | sematic_embeds = pre_acoustic_embeds |
| | | |
| | | pdb.set_trace() |
| | | # 1. Forward decoder |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info |
| | |
| | | loss_ideal = None |
| | | ''' |
| | | loss_ideal = None |
| | | |
| | | pdb.set_trace() |
| | | if decoder_out_1st is None: |
| | | decoder_out_1st = decoder_out |
| | | # 2. Compute attention loss |