| | |
| | | |
| | | # bias encoder |
| | | if self.bias_encoder_type == 'lstm': |
| | | logging.warning("enable bias encoder sampling and contextual training") |
| | | self.bias_encoder = torch.nn.LSTM(self.inner_dim, |
| | | self.inner_dim, |
| | | 2, |
| | |
| | | self.lstm_proj = None |
| | | self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) |
| | | elif self.bias_encoder_type == 'mean': |
| | | logging.warning("enable bias encoder sampling and contextual training") |
| | | self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) |
| | | else: |
| | | logging.error("Unsupport bias encoder type: {}".format(self.bias_encoder_type)) |
| | |
| | | ys_pad_lens, |
| | | hw_list, |
| | | nfilter=50, |
| | | seaco_weight=1.0): |
| | | seaco_weight=1.0): |
| | | # decoder forward |
| | | decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) |
| | | decoder_pred = torch.log_softmax(decoder_out, dim=-1) |
| | |
| | | |
| | | dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation |
| | | dha_pred = torch.log_softmax(dha_output, dim=-1) |
| | | # import pdb; pdb.set_trace() |
| | | def _merge_res(dec_output, dha_output): |
| | | lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0]) |
| | | dha_ids = dha_output.max(-1)[-1][0] |
| | | dha_ids = dha_output.max(-1)[-1]# [0] |
| | | dha_mask = (dha_ids == 8377).int().unsqueeze(-1) |
| | | a = (1 - lmbd) / lmbd |
| | | b = 1 / lmbd |
| | |
| | | logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) |
| | | return logits |
| | | merged_pred = _merge_res(decoder_pred, dha_pred) |
| | | # import pdb; pdb.set_trace() |
| | | return merged_pred |
| | | else: |
| | | return decoder_pred |
| | |
| | | token, timestamp) |
| | | |
| | | result_i = {"key": key[i], "text": text_postprocessed, |
| | | "timestamp": time_stamp_postprocessed, |
| | | "timestamp": time_stamp_postprocessed, "raw_text": copy.copy(text_postprocessed) |
| | | } |
| | | |
| | | if ibest_writer is not None: |
| | | ibest_writer["token"][key[i]] = " ".join(token) |
| | | # ibest_writer["text"][key[i]] = text |
| | | # ibest_writer["raw_text"][key[i]] = text |
| | | ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed |
| | | ibest_writer["text"][key[i]] = text_postprocessed |
| | | else: |