| | |
| | | ys_pad_lens, |
| | | hw_list, |
| | | nfilter=50, |
| | | seaco_weight=1.0): |
| | | seaco_weight=1.0): |
| | | # decoder forward |
| | | decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) |
| | | decoder_pred = torch.log_softmax(decoder_out, dim=-1) |
| | |
| | | |
| | | dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation |
| | | dha_pred = torch.log_softmax(dha_output, dim=-1) |
| | | # import pdb; pdb.set_trace() |
| | | def _merge_res(dec_output, dha_output): |
| | | lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0]) |
| | | dha_ids = dha_output.max(-1)[-1][0] |
| | | dha_ids = dha_output.max(-1)[-1]# [0] |
| | | dha_mask = (dha_ids == 8377).int().unsqueeze(-1) |
| | | a = (1 - lmbd) / lmbd |
| | | b = 1 / lmbd |
| | |
| | | logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) |
| | | return logits |
| | | merged_pred = _merge_res(decoder_pred, dha_pred) |
| | | # import pdb; pdb.set_trace() |
| | | return merged_pred |
| | | else: |
| | | return decoder_pred |
| | |
| | | token, timestamp) |
| | | |
| | | result_i = {"key": key[i], "text": text_postprocessed, |
| | | "timestamp": time_stamp_postprocessed, |
| | | "timestamp": time_stamp_postprocessed, "raw_text": copy.copy(text_postprocessed) |
| | | } |
| | | |
| | | if ibest_writer is not None: |
| | | ibest_writer["token"][key[i]] = " ".join(token) |
| | | # ibest_writer["text"][key[i]] = text |
| | | # ibest_writer["raw_text"][key[i]] = text |
| | | ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed |
| | | ibest_writer["text"][key[i]] = text_postprocessed |
| | | else: |