| | |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | """ |
| | | text_lengths = text_lengths.squeeze() |
| | | speech_lengths = speech_lengths.squeeze() |
| | | assert text_lengths.dim() == 1, text_lengths.shape |
| | | # Check that batch_size is unified |
| | | assert ( |
| | |
| | | |
| | | hotword_pad = kwargs.get("hotword_pad") |
| | | hotword_lengths = kwargs.get("hotword_lengths") |
| | | dha_pad = kwargs.get("dha_pad") |
| | | seaco_label_pad = kwargs.get("seaco_label_pad") |
| | | |
| | | batch_size = speech.shape[0] |
| | | # for data-parallel |
| | |
| | | ys_lengths, |
| | | hotword_pad, |
| | | hotword_lengths, |
| | | dha_pad, |
| | | seaco_label_pad, |
| | | ) |
| | | if self.train_decoder: |
| | | loss_att, acc_att = self._calc_att_loss( |
| | |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size) |
| | | batch_size = (text_lengths + self.predictor_bias).sum() |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | |
| | | ys_lengths: torch.Tensor, |
| | | hotword_pad: torch.Tensor, |
| | | hotword_lengths: torch.Tensor, |
| | | dha_pad: torch.Tensor, |
| | | seaco_label_pad: torch.Tensor, |
| | | ): |
| | | # predictor forward |
| | | encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( |
| | | encoder_out.device) |
| | | pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask, |
| | | ignore_id=self.ignore_id) |
| | | pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0] |
| | | # decoder forward |
| | | decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True) |
| | | selected = self._hotword_representation(hotword_pad, |
| | |
| | | dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_out, ys_lengths) |
| | | merged = self._merge(cif_attended, dec_attended) |
| | | dha_output = self.hotword_output_layer(merged[:, :-1]) # remove the last token in loss calculation |
| | | loss_att = self.criterion_seaco(dha_output, dha_pad) |
| | | loss_att = self.criterion_seaco(dha_output, seaco_label_pad) |
| | | return loss_att |
| | | |
| | | def _seaco_decode_with_ASF(self, |
| | |
| | | |
| | | return results, meta_data |
| | | |
| | | |
| | | def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None): |
| | | def load_seg_dict(seg_dict_file): |
| | | seg_dict = {} |
| | |
| | | hotword_list = None |
| | | return hotword_list |
| | | |
def export(
    self,
    **kwargs,
):
    """Rebuild this model for export via ``export_rebuild_model``.

    All keyword arguments are forwarded unchanged to
    ``export_meta.export_rebuild_model`` together with ``model=self``.
    If the caller does not supply ``max_seq_len``, it defaults to 512.

    Returns:
        Whatever ``export_rebuild_model`` returns.
    """
    # Imported lazily so the export helpers are only loaded when exporting.
    from .export_meta import export_rebuild_model

    # ``**kwargs`` is a fresh dict, so filling in the default here never
    # mutates a mapping owned by the caller.
    kwargs.setdefault('max_seq_len', 512)
    return export_rebuild_model(model=self, **kwargs)
| | | |