
stats = {}

# 1. Forward decoder
# ys_pad: [sos, task, lid, text, eos]
decoder_out = self.model.decoder(
    x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
)
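
# Note on the decoder signature (interpretation): x/xa mirror OpenAI Whisper's
# TextDecoder.forward(x, xa), where x is the token sequence and xa the encoded
# audio states consumed by cross-attention; hlens/ys_in_lens appear to be
# additions of this wrapper carrying the valid encoder/decoder lengths.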

# 2. Compute attention loss
mask = torch.ones_like(ys_pad) * (-1)  # [sos, task, lid, text, eos] -> [-1, -1, -1, -1, -1]
ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(
    torch.int64
)  # target_mask is [0, 0, 1, 1, 1]: keep [lid, text, eos], everything else -> -1
ys_pad_mask[ys_pad_mask == 0] = -1  # [-1, -1, lid, text, eos]; padding zeros also become -1
# decoder_out[:, :-1, :]: predictions at the [sos, task, lid, text] positions
# ys_pad_mask[:, 1:]: shifted targets [-1, lid, text, eos]; -1 is ignored by the loss
loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
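
# Illustration (not the model's code): how -1 targets drop positions from the
# loss. Assuming criterion_att is a cross-entropy-style loss with ignore index
# -1 (e.g. a LabelSmoothingLoss built with padding_idx=-1), the [sos, task]
# slots and padding contribute nothing:
#
#   logits = torch.randn(1, 4, 10)             # stands in for decoder_out[:, :-1, :]
#   targets = torch.tensor([[-1, 5, 7, 2]])    # stands in for ys_pad_mask[:, 1:]
#   ce = torch.nn.CrossEntropyLoss(ignore_index=-1)
#   loss = ce(logits.transpose(1, 2), targets)  # position 0 is excluded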

with torch.no_grad():
    # Reconstructed from an elided span: assuming FunASR's load_audio_text_image_video
    # (it matches the data_type/tokenizer kwargs), with `data_in`/`frontend` as method
    # arguments; meta_data/time1 are re-added because the timing code below uses them.
    meta_data = {}
    time1 = time.perf_counter()
    audio_sample_list = load_audio_text_image_video(
        data_in,
        fs=frontend.fs,
        audio_fs=kwargs.get("fs", 16000),
        data_type=kwargs.get("data_type", "sound"),
        tokenizer=tokenizer,
    )

    if (
        isinstance(kwargs.get("data_type", None), (list, tuple))
        and len(kwargs.get("data_type", [])) > 1
    ):
        # paired input streams (e.g. audio + reference text): the loader
        # returned (audio_samples, tokenized_text)
        audio_sample_list, text_token_int_list = audio_sample_list
        text_token_int = text_token_int_list[0]
    else:
        text_token_int = None
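
    # Illustration (hypothetical call; argument names assumed): a paired input
    # that takes the branch above might look like
    #   model.inference(
    #       data_in=["utt1.wav", "reference transcript"],
    #       data_type=("sound", "text"),
    #       tokenizer=tokenizer, frontend=frontend, device="cuda",
    #   )
    # leaving text_token_int holding the tokenized reference for teacher forcing.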
    time2 = time.perf_counter()
    meta_data["load_data"] = f"{time2 - time1:0.3f}"
    # elided arguments reconstructed here, assuming FunASR's extract_fbank signature
    speech, speech_lengths = extract_fbank(
        audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
    )
    # reconstructed assignment: encoder_out / encoder_out_lens are consumed below
    encoder_out, encoder_out_lens = self.model.encoder(
        speech[None, :, :].permute(0, 2, 1), speech_lengths
    )
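
    # Shape sketch (assumed, Whisper-style front end): speech comes back as
    # [frames, n_mels], so speech[None, :, :] -> [1, frames, n_mels] and
    # .permute(0, 2, 1) -> [1, n_mels, frames], the layout the encoder's
    # Conv1d stem expects.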

    if text_token_int is not None:
        # a reference transcript was provided: decode it teacher-forced
        # instead of running beam search
        i = 0
        results = []
        ibest_writer = None
        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = self.writer["1best_recog"]

        # 1. Forward decoder
        ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(
            kwargs["device"]
        )[None, :]
        ys_pad_lens = torch.tensor(
            [len(sos_int + text_token_int)], dtype=torch.int64
        ).to(kwargs["device"])[None, :]
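        # Interpretation: sos_int is presumably the special-token prefix
        # ([sos, task, lid] in the forward() comments above) and text_token_int
        # the tokenized reference, making ys_pad a [1, L] teacher-forcing prompt.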
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )

        # greedy read-out: argmax over the teacher-forced logits
        token_int = decoder_out.argmax(-1)[0, :].tolist()
        text = tokenizer.decode(token_int)

        result_i = {"key": key[i], "text": text}
        results.append(result_i)

        if ibest_writer is not None:
            # ibest_writer["token"][key[i]] = " ".join(token)
            ibest_writer["text"][key[i]] = text
        return results, meta_data
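
    # With no reference text, control falls through to beam search below.
    # encoder_out[0] drops the batch dimension: an ESPnet-style BeamSearch
    # consumes a single utterance's encoder states of shape [frames, dim].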
    # c. Pass the encoder output to beam search
    nbest_hyps = self.beam_search(
        x=encoder_out[0],