| | |
| | | logging.info("asr_train_args: {}".format(asr_train_args)) |
| | | asr_model.to(dtype=getattr(torch, dtype)).eval() |
| | | |
| | | ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) |
| | | if asr_model.ctc != None: |
| | | ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) |
| | | scorers.update( |
| | | ctc=ctc |
| | | ) |
| | | token_list = asr_model.token_list |
| | | scorers.update( |
| | | ctc=ctc, |
| | | length_bonus=LengthBonus(len(token_list)), |
| | | ) |
| | | |
| | |
| | | self.converter = converter |
| | | self.tokenizer = tokenizer |
| | | is_use_lm = lm_weight != 0.0 and lm_file is not None |
| | | if ctc_weight == 0.0 and not is_use_lm: |
| | | if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: |
| | | beam_search = None |
| | | self.beam_search = beam_search |
| | | logging.info(f"Beam_search: {self.beam_search}") |
| | |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # remove blank symbol id, which is assumed to be 0 |
| | | token_int = list(filter(lambda x: x != 0, token_int)) |
| | | token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) |
| | | |
| | | # Change integer-ids to tokens |
| | | token = self.converter.ids2tokens(token_int) |
| | |
| | | finish_count += 1 |
| | | # asr_utils.print_progress(finish_count / file_count) |
| | | if writer is not None: |
| | | ibest_writer["text"][key] = text |
| | | ibest_writer["text"][key] = text_postprocessed |
| | | |
| | | logging.info("decoding, utt: {}, predictions: {}".format(key, text)) |
| | | rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)) |