| | |
| | | import numpy as np |
| | | |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.models.ctc import CTC |
| | |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | preencoder: Optional[AbsPreEncoder], |
| | | encoder: AbsEncoder, |
| | | postencoder: Optional[AbsPostEncoder], |
| | | decoder: AbsDecoder, |
| | | ctc: CTC, |
| | | ctc_weight: float = 0.5, |
| | |
| | | target_buffer_length: int = -1, |
| | | inner_dim: int = 256, |
| | | bias_encoder_type: str = 'lstm', |
| | | use_decoder_embedding: bool = True, |
| | | use_decoder_embedding: bool = False, |
| | | crit_attn_weight: float = 0.0, |
| | | crit_attn_smooth: float = 0.0, |
| | | bias_encoder_dropout_rate: float = 0.0, |
| | | preencoder: Optional[AbsPreEncoder] = None, |
| | | postencoder: Optional[AbsPostEncoder] = None, |
| | | ): |
| | | assert check_argument_types() |
| | | assert 0.0 <= ctc_weight <= 1.0, ctc_weight |
| | | assert 0.0 <= interctc_weight < 1.0, interctc_weight |
| | | |
| | |
| | | |
| | | # 1. Forward decoder |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info, ret_attn=(ideal_attn is not None) |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info |
| | | ) |
| | | decoder_out, _, attn = decoder_outs[0], decoder_outs[1], decoder_outs[2] |
| | | decoder_out, _ = decoder_outs[0], decoder_outs[1] |
| | | ''' |
| | | if self.crit_attn_weight > 0 and attn.shape[-1] > 1: |
| | | ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0) |
| | | attn_non_blank = attn[:,:,:,:-1] |
| | |
| | | loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device)) |
| | | else: |
| | | loss_ideal = None |
| | | ''' |
| | | loss_ideal = None |
| | | |
| | | if decoder_out_1st is None: |
| | | decoder_out_1st = decoder_out |
| | |
| | | input_mask_expand_dim, 0) |
| | | return sematic_embeds * tgt_mask, decoder_out * tgt_mask |
| | | |
| | | def cal_decoder_with_predictor_with_hwlist_advanced(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None): |
| | | def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0): |
| | | if hw_list is None: |
| | | hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list |
| | | hw_list_pad = pad_list(hw_list, 0) |
| | |
| | | else: |
| | | hw_embed = self.bias_embed(hw_list_pad) |
| | | hw_embed, (h_n, _) = self.bias_encoder(hw_embed) |
| | | hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1) |
| | | else: |
| | | # hw_list = hw_list[1:] + [hw_list[0]] # reorder |
| | | hw_lengths = [len(i) for i in hw_list] |
| | | hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device) |
| | | if self.use_decoder_embedding: |
| | |
| | | hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True, |
| | | enforce_sorted=False) |
| | | _, (h_n, _) = self.bias_encoder(hw_embed) |
| | | # hw_embed, _ = torch.nn.utils.rnn.pad_packed_sequence(hw_embed, batch_first=True) |
| | | if h_n.shape[1] > 2000: # large hotword list |
| | | _h_n = self.pick_hwlist_group(h_n.squeeze(0), encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens) |
| | | if _h_n is not None: |
| | | h_n = _h_n |
| | | hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1) |
| | | # import pdb; pdb.set_trace() |
| | | |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale |
| | | ) |
| | | decoder_out = decoder_outs[0] |
| | | decoder_out = torch.log_softmax(decoder_out, dim=-1) |
| | | return decoder_out, ys_pad_lens |
| | | |
| | | def pick_hwlist_group(self, hw_embed, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens): |
| | | max_attn_score = 0.0 |
| | | # max_attn_index = 0 |
| | | argmax_g = None |
| | | non_blank = hw_embed[-1] |
| | | hw_embed_groups = hw_embed[:-1].split(2000) |
| | | for i, g in enumerate(hw_embed_groups): |
| | | g = torch.cat([g, non_blank.unsqueeze(0)], dim=0) |
| | | _ = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=g.unsqueeze(0) |
| | | ) |
| | | attn = self.decoder.bias_decoder.src_attn.attn[0] |
| | | _max_attn_score = attn.max(0)[0][:,:-1].max() |
| | | if _max_attn_score > max_attn_score: |
| | | max_attn_score = _max_attn_score |
| | | # max_attn_index = i |
| | | argmax_g = g |
| | | # import pdb; pdb.set_trace() |
| | | return argmax_g |