| | |
| | | |
| | | self.vocab_size = decoder.vocab_size |
| | | |
| | | assert beam_size <= self.vocab_size, ( |
| | | "beam_size (%d) should be smaller than or equal to vocabulary size (%d)." |
| | | % ( |
| | | beam_size, |
| | | self.vocab_size, |
| | | ) |
| | | assert ( |
| | | beam_size <= self.vocab_size |
| | | ), "beam_size (%d) should be smaller than or equal to vocabulary size (%d)." % ( |
| | | beam_size, |
| | | self.vocab_size, |
| | | ) |
| | | self.beam_size = beam_size |
| | | |
| | | if search_type == "default": |
| | | self.search_algorithm = self.default_beam_search |
| | | elif search_type == "tsd": |
| | | assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % ( |
| | | max_sym_exp |
| | | ) |
| | | assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (max_sym_exp) |
| | | self.max_sym_exp = max_sym_exp |
| | | |
| | | self.search_algorithm = self.time_sync_decoding |
| | |
| | | |
| | | self.search_algorithm = self.modified_adaptive_expansion_search |
| | | else: |
| | | raise NotImplementedError( |
| | | "Specified search type (%s) is not supported." % search_type |
| | | ) |
| | | raise NotImplementedError("Specified search type (%s) is not supported." % search_type) |
| | | |
| | | self.use_lm = lm is not None |
| | | |
| | |
| | | k_expansions = [] |
| | | |
| | | for i, hyp in enumerate(hyps): |
| | | hyp_i = [ |
| | | (int(k), hyp.score + float(v)) |
| | | for k, v in zip(topk_idx[i], topk_logp[i]) |
| | | ] |
| | | hyp_i = [(int(k), hyp.score + float(v)) for k, v in zip(topk_idx[i], topk_logp[i])] |
| | | k_best_exp = max(hyp_i, key=lambda x: x[1])[1] |
| | | |
| | | k_expansions.append( |
| | | sorted( |
| | | filter( |
| | | lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i |
| | | ), |
| | | filter(lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i), |
| | | key=lambda x: x[1], |
| | | reverse=True, |
| | | ) |
| | |
| | | |
| | | if self.use_lm: |
| | | lm_scores, lm_state = self.lm.score( |
| | | torch.LongTensor( |
| | | [self.sos] + max_hyp.yseq[1:], device=self.decoder.device |
| | | ), |
| | | torch.LongTensor([self.sos] + max_hyp.yseq[1:], device=self.decoder.device), |
| | | max_hyp.lm_state, |
| | | None, |
| | | ) |
| | |
| | | break |
| | | |
| | | return kept_hyps |
| | | |
| | | |
| | | def align_length_sync_decoding( |
| | | self, |
| | | enc_out: torch.Tensor, |