| | |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
def encode_chunk(
    self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Frontend + Encoder for one streaming chunk.

    Note that this method is used by asr_inference.py

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        cache: streaming state; its ``"encoder"`` entry is handed to
            ``self.encoder.forward_chunk``. It is required: the ``None``
            default is kept only so the signature stays unchanged.

    Returns:
        A pair ``(encoder_out, lengths)``. When the encoder yields
        intermediate outputs, the first element is itself the pair
        ``(encoder_out, intermediate_outs)`` and ``lengths`` is the length
        value coming out of the encoder / post-encoder. Otherwise
        ``lengths`` is a new one-element tensor holding
        ``encoder_out.size(1)``.

    Raises:
        TypeError: if ``cache`` is ``None``.
    """
    # Fail fast with a readable message instead of the opaque
    # "'NoneType' object is not subscriptable" raised further down.
    if cache is None:
        raise TypeError(
            "encode_chunk() requires a cache dict with an 'encoder' entry, "
            "got None"
        )
    encoder_cache = cache["encoder"]

    with autocast(False):
        # 1. Extract feats
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)

        # 2. Data augmentation (training only)
        if self.specaug is not None and self.training:
            feats, feats_lengths = self.specaug(feats, feats_lengths)

        # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            feats, feats_lengths = self.normalize(feats, feats_lengths)

    # Pre-encoder, e.g. used for raw input data
    if self.preencoder is not None:
        feats, feats_lengths = self.preencoder(feats, feats_lengths)

    # 4. Forward encoder
    # feats: (Batch, Length, Dim)
    # -> encoder_out: (Batch, Length2, Dim2)
    encoder_kwargs = {"cache": encoder_cache}
    if self.encoder.interctc_use_conditioning:
        # The CTC module is only passed when the encoder conditions on it.
        encoder_kwargs["ctc"] = self.ctc
    encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
        feats, feats_lengths, **encoder_kwargs
    )

    # The encoder may return (final_output, intermediate_outputs).
    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        encoder_out, intermediate_outs = encoder_out[0], encoder_out[1]

    # Post-encoder, e.g. NLU
    if self.postencoder is not None:
        encoder_out, encoder_out_lens = self.postencoder(
            encoder_out, encoder_out_lens
        )

    if intermediate_outs is not None:
        return (encoder_out, intermediate_outs), encoder_out_lens

    # NOTE(review): the length is rebuilt from the time dimension as a
    # single-entry tensor, ignoring encoder_out_lens -- presumably because
    # chunked inference handles one utterance at a time; confirm with the
    # caller before changing.
    return encoder_out, torch.tensor([encoder_out.size(1)])
| | | |
def calc_predictor(self, encoder_out, encoder_out_lens):
    """Run the predictor over a full (non-streaming) encoder output.

    A mask that is True on non-padded frames is built from
    ``encoder_out_lens``, given a singleton middle axis, and moved to the
    device of ``encoder_out`` before being handed to ``self.predictor``.

    Args:
        encoder_out: encoder output; its size along dim 1 sets the mask width.
        encoder_out_lens: per-utterance lengths used to build the mask.

    Returns:
        The four predictor outputs:
        ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``.
    """
    max_frames = encoder_out.size(1)
    padding = make_pad_mask(encoder_out_lens, maxlen=max_frames)
    # Invert the padding mask so that valid positions are True.
    valid_mask = (~padding[:, None, :]).to(encoder_out.device)

    # No target is supplied (second positional argument is None).
    acoustic_embeds, token_length, alphas, peak_index = self.predictor(
        encoder_out, None, valid_mask, ignore_id=self.ignore_id
    )
    return acoustic_embeds, token_length, alphas, peak_index
| | | |
def calc_predictor_chunk(self, encoder_out, cache=None):
    """Run the predictor on one streaming chunk of encoder output.

    Args:
        encoder_out: encoder output for the current chunk; passed to
            ``self.predictor.forward_chunk`` unchanged.
        cache: streaming state; its ``"encoder"`` entry is the cache given
            to the predictor. It is required: the ``None`` default is kept
            only so the signature stays unchanged.

    Returns:
        The four values produced by ``self.predictor.forward_chunk``:
        ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``.

    Raises:
        TypeError: if ``cache`` is ``None``.
    """
    # Fail fast with a readable message instead of the opaque
    # "'NoneType' object is not subscriptable".
    if cache is None:
        raise TypeError(
            "calc_predictor_chunk() requires a cache dict with an 'encoder' "
            "entry, got None"
        )

    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
        self.predictor.forward_chunk(encoder_out, cache["encoder"])
    )
    return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
| | | |
| | | def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens): |
| | |
| | | decoder_out = decoder_outs[0] |
| | | decoder_out = torch.log_softmax(decoder_out, dim=-1) |
| | | return decoder_out, ys_pad_lens |
| | | |
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
    """Decode one streaming chunk and return log-probabilities.

    Args:
        encoder_out: encoder output for the current chunk.
        sematic_embeds: embeddings fed to the decoder alongside
            ``encoder_out``.
        cache: streaming state; its ``"decoder"`` entry is passed to
            ``self.decoder.forward_chunk``. It is required: the ``None``
            default is kept only so the signature stays unchanged.

    Returns:
        ``log_softmax`` over the last dimension of the decoder output.

    Raises:
        TypeError: if ``cache`` is ``None``.
    """
    # Fail fast with a readable message instead of the opaque
    # "'NoneType' object is not subscriptable".
    if cache is None:
        raise TypeError(
            "cal_decoder_with_predictor_chunk() requires a cache dict with a "
            "'decoder' entry, got None"
        )

    decoder_out = self.decoder.forward_chunk(
        encoder_out, sematic_embeds, cache["decoder"]
    )
    return torch.log_softmax(decoder_out, dim=-1)
| | | |
| | | def _extract_feats( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | loss_att, acc_att, cer_att, wer_att = None, None, None, None |
| | | loss_ctc, cer_ctc = None, None |
| | | loss_pre = None |
| | | stats = dict() |
| | | |
| | | # 1. CTC branch |
| | | if self.ctc_weight != 0.0: |
| | | loss_ctc, cer_ctc = self._calc_ctc_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | # Collect CTC branch stats |
| | | stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None |
| | | stats["cer_ctc"] = cer_ctc |
| | | |
| | | # Intermediate CTC (optional) |
| | | loss_interctc = 0.0 |
| | | if self.interctc_weight != 0.0 and intermediate_outs is not None: |
| | | for layer_idx, intermediate_out in intermediate_outs: |
| | | # we assume intermediate_out has the same length & padding |
| | | # as those of encoder_out |
| | | loss_ic, cer_ic = self._calc_ctc_loss( |
| | | intermediate_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | loss_interctc = loss_interctc + loss_ic |
| | | |
| | | # Collect Intermediate CTC stats |
| | | stats["loss_interctc_layer{}".format(layer_idx)] = ( |
| | | loss_ic.detach() if loss_ic is not None else None |
| | | ) |
| | | stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic |
| | | |
| | | loss_interctc = loss_interctc / len(intermediate_outs) |
| | | |
| | | # calculate whole encoder loss |
| | | loss_ctc = ( |
| | | 1 - self.interctc_weight |
| | | ) * loss_ctc + self.interctc_weight * loss_interctc |
| | | |
| | | # 2b. Attention decoder branch |
| | | if self.ctc_weight != 1.0: |
| | | loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | loss_pre2 = self._calc_pre2_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | loss = loss_pre2 |
| | | # 3. CTC-Att loss definition |
| | | if self.ctc_weight == 0.0: |
| | | loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 |
| | | elif self.ctc_weight == 1.0: |
| | | loss = loss_ctc |
| | | else: |
| | | loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5 |
| | | |
| | | # Collect Attn branch stats |
| | | stats["loss_att"] = loss_att.detach() if loss_att is not None else None |
| | | stats["acc"] = acc_att |
| | | stats["cer"] = cer_att |
| | | stats["wer"] = wer_att |
| | | stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None |
| | | stats["loss_pre2"] = loss_pre2.detach().cpu() |
| | | |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | |
| | | inner_dim: int = 256, |
| | | bias_encoder_type: str = 'lstm', |
| | | label_bracket: bool = False, |
| | | use_decoder_embedding: bool = False, |
| | | ): |
| | | assert check_argument_types() |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | |
| | | self.hotword_buffer = None |
| | | self.length_record = [] |
| | | self.current_buffer_length = 0 |
| | | self.use_decoder_embedding = use_decoder_embedding |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | hw_list.append(hw_tokens) |
| | | # padding |
| | | hw_list_pad = pad_list(hw_list, 0) |
| | | hw_embed = self.decoder.embed(hw_list_pad) |
| | | if self.use_decoder_embedding: |
| | | hw_embed = self.decoder.embed(hw_list_pad) |
| | | else: |
| | | hw_embed = self.bias_embed(hw_list_pad) |
| | | hw_embed, (_, _) = self.bias_encoder(hw_embed) |
| | | _ind = np.arange(0, len(hw_list)).tolist() |
| | | # update self.hotword_buffer, throw a part if oversize |
| | |
| | | # default hotword list |
| | | hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)] # empty hotword list |
| | | hw_list_pad = pad_list(hw_list, 0) |
| | | hw_embed = self.bias_embed(hw_list_pad) |
| | | if self.use_decoder_embedding: |
| | | hw_embed = self.decoder.embed(hw_list_pad) |
| | | else: |
| | | hw_embed = self.bias_embed(hw_list_pad) |
| | | _, (h_n, _) = self.bias_encoder(hw_embed) |
| | | contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1) |
| | | else: |
| | | hw_lengths = [len(i) for i in hw_list] |
| | | hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device) |
| | | hw_embed = self.bias_embed(hw_list_pad) |
| | | if self.use_decoder_embedding: |
| | | hw_embed = self.decoder.embed(hw_list_pad) |
| | | else: |
| | | hw_embed = self.bias_embed(hw_list_pad) |
| | | hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True, |
| | | enforce_sorted=False) |
| | | _, (h_n, _) = self.bias_encoder(hw_embed) |
| | |
| | | "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf, |
| | | var_dict_tf[name_tf].shape)) |
| | | |
| | | return var_dict_torch_update |
| | | return var_dict_torch_update |