| | |
| | | """RNN decoder module.""" |
| | | |
| | | import logging |
| | | import math |
| | | import random |
| | |
| | | from funasr.metrics import end_detect |
| | | from funasr.models.transformer.utils.nets_utils import mask_by_length |
| | | from funasr.models.transformer.utils.nets_utils import pad_list |
| | | from funasr.models.transformer.utils.nets_utils import th_accuracy |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.models.transformer.utils.nets_utils import to_device |
| | | from funasr.models.language_model.rnn.attentions import att_to_numpy |
| | | |
| | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | eprojs, |
| | | odim, |
| | | dtype, |
| | | dlayers, |
| | | dunits, |
| | | sos, |
| | | eos, |
| | | att, |
| | | verbose=0, |
| | | char_list=None, |
| | | labeldist=None, |
| | | lsm_weight=0.0, |
| | | sampling_probability=0.0, |
| | | dropout=0.0, |
| | | context_residual=False, |
| | | replace_sos=False, |
| | | num_encs=1, |
| | | self, |
| | | eprojs, |
| | | odim, |
| | | dtype, |
| | | dlayers, |
| | | dunits, |
| | | sos, |
| | | eos, |
| | | att, |
| | | verbose=0, |
| | | char_list=None, |
| | | labeldist=None, |
| | | lsm_weight=0.0, |
| | | sampling_probability=0.0, |
| | | dropout=0.0, |
| | | context_residual=False, |
| | | replace_sos=False, |
| | | num_encs=1, |
| | | ): |
| | | |
| | | torch.nn.Module.__init__(self) |
| | |
| | | self.decoder = torch.nn.ModuleList() |
| | | self.dropout_dec = torch.nn.ModuleList() |
| | | self.decoder += [ |
| | | torch.nn.LSTMCell(dunits + eprojs, dunits) |
| | | if self.dtype == "lstm" |
| | | else torch.nn.GRUCell(dunits + eprojs, dunits) |
| | | ( |
| | | torch.nn.LSTMCell(dunits + eprojs, dunits) |
| | | if self.dtype == "lstm" |
| | | else torch.nn.GRUCell(dunits + eprojs, dunits) |
| | | ) |
| | | ] |
| | | self.dropout_dec += [torch.nn.Dropout(p=dropout)] |
| | | for _ in six.moves.range(1, self.dlayers): |
| | | self.decoder += [ |
| | | torch.nn.LSTMCell(dunits, dunits) |
| | | if self.dtype == "lstm" |
| | | else torch.nn.GRUCell(dunits, dunits) |
| | | ( |
| | | torch.nn.LSTMCell(dunits, dunits) |
| | | if self.dtype == "lstm" |
| | | else torch.nn.GRUCell(dunits, dunits) |
| | | ) |
| | | ] |
| | | self.dropout_dec += [torch.nn.Dropout(p=dropout)] |
| | | # NOTE: dropout is applied only for the vertical connections |
| | |
| | | else: |
| | | z_list[0] = self.decoder[0](ey, z_prev[0]) |
| | | for i in six.moves.range(1, self.dlayers): |
| | | z_list[i] = self.decoder[i]( |
| | | self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i] |
| | | ) |
| | | z_list[i] = self.decoder[i](self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]) |
| | | return z_list, c_list |
| | | |
| | | def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None): |
| | |
| | | ) |
| | | ) |
| | | logging.info( |
| | | self.__class__.__name__ |
| | | + " output lengths: " |
| | | + str([y.size(0) for y in ys_out]) |
| | | self.__class__.__name__ + " output lengths: " + str([y.size(0) for y in ys_out]) |
| | | ) |
| | | |
| | | # initialization |
| | |
| | | ys_hat = y_all.view(batch, olength, -1) |
| | | ys_true = ys_out_pad |
| | | for (i, y_hat), y_true in zip( |
| | | enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy() |
| | | enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy() |
| | | ): |
| | | if i == MAX_DECODER_OUTPUT: |
| | | break |
| | |
| | | weights_ctc_dec = recog_args.weights_ctc_dec / np.sum( |
| | | recog_args.weights_ctc_dec |
| | | ) # normalize |
| | | logging.info( |
| | | "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]) |
| | | ) |
| | | logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])) |
| | | else: |
| | | weights_ctc_dec = [1.0] |
| | | |
| | |
| | | hyp["a_prev"][self.num_encs], |
| | | ) |
| | | ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim) |
| | | z_list, c_list = self.rnn_forward( |
| | | ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"] |
| | | ) |
| | | z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]) |
| | | |
| | | # get nbest local scores and their ids |
| | | if self.context_residual: |
| | |
| | | local_att_scores = F.log_softmax(logits, dim=1) |
| | | if rnnlm: |
| | | rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy) |
| | | local_scores = ( |
| | | local_att_scores + recog_args.lm_weight * local_lm_scores |
| | | ) |
| | | local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores |
| | | else: |
| | | local_scores = local_att_scores |
| | | |
| | |
| | | ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx]( |
| | | hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx] |
| | | ) |
| | | local_scores = (1.0 - ctc_weight) * local_att_scores[ |
| | | :, local_best_ids[0] |
| | | ] |
| | | local_scores = (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] |
| | | if self.num_encs == 1: |
| | | local_scores += ctc_weight * torch.from_numpy( |
| | | ctc_scores[0] - hyp["ctc_score_prev"][0] |
| | |
| | | else: |
| | | for idx in range(self.num_encs): |
| | | local_scores += ( |
| | | ctc_weight |
| | | * weights_ctc_dec[idx] |
| | | * torch.from_numpy( |
| | | ctc_scores[idx] - hyp["ctc_score_prev"][idx] |
| | | ) |
| | | ctc_weight |
| | | * weights_ctc_dec[idx] |
| | | * torch.from_numpy(ctc_scores[idx] - hyp["ctc_score_prev"][idx]) |
| | | ) |
| | | if rnnlm: |
| | | local_scores += ( |
| | | recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]] |
| | | ) |
| | | local_best_scores, joint_best_ids = torch.topk( |
| | | local_scores, beam, dim=1 |
| | | ) |
| | | local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]] |
| | | local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1) |
| | | local_best_ids = local_best_ids[:, joint_best_ids[0]] |
| | | else: |
| | | local_best_scores, local_best_ids = torch.topk( |
| | | local_scores, beam, dim=1 |
| | | ) |
| | | local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1) |
| | | |
| | | for j in six.moves.range(beam): |
| | | new_hyp = {} |
| | |
| | | if self.num_encs == 1: |
| | | new_hyp["a_prev"] = att_w[:] |
| | | else: |
| | | new_hyp["a_prev"] = [ |
| | | att_w_list[idx][:] for idx in range(self.num_encs + 1) |
| | | ] |
| | | new_hyp["a_prev"] = [att_w_list[idx][:] for idx in range(self.num_encs + 1)] |
| | | new_hyp["score"] = hyp["score"] + local_best_scores[0, j] |
| | | new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) |
| | | new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"] |
| | |
| | | new_hyp["rnnlm_prev"] = rnnlm_state |
| | | if lpz[0] is not None: |
| | | new_hyp["ctc_state_prev"] = [ |
| | | ctc_states[idx][joint_best_ids[0, j]] |
| | | for idx in range(self.num_encs) |
| | | ctc_states[idx][joint_best_ids[0, j]] for idx in range(self.num_encs) |
| | | ] |
| | | new_hyp["ctc_score_prev"] = [ |
| | | ctc_scores[idx][joint_best_ids[0, j]] |
| | | for idx in range(self.num_encs) |
| | | ctc_scores[idx][joint_best_ids[0, j]] for idx in range(self.num_encs) |
| | | ] |
| | | # will be (2 x beam) hyps at most |
| | | hyps_best_kept.append(new_hyp) |
| | | |
| | | hyps_best_kept = sorted( |
| | | hyps_best_kept, key=lambda x: x["score"], reverse=True |
| | | )[:beam] |
| | | hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x["score"], reverse=True)[ |
| | | :beam |
| | | ] |
| | | |
| | | # sort and get nbest |
| | | hyps = hyps_best_kept |
| | | logging.debug("number of pruned hypotheses: " + str(len(hyps))) |
| | | logging.debug( |
| | | "best hypo: " |
| | | + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]) |
| | | ) |
| | | logging.debug("best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])) |
| | | |
| | | # add eos in the final loop so that at least one hypothesis is ended |
| | | if i == maxlen - 1: |
| | |
| | | if len(hyp["yseq"]) > minlen: |
| | | hyp["score"] += (i + 1) * penalty |
| | | if rnnlm: # Word LM needs to add final <eos> score |
| | | hyp["score"] += recog_args.lm_weight * rnnlm.final( |
| | | hyp["rnnlm_prev"] |
| | | ) |
| | | hyp["score"] += recog_args.lm_weight * rnnlm.final(hyp["rnnlm_prev"]) |
| | | ended_hyps.append(hyp) |
| | | else: |
| | | remained_hyps.append(hyp) |
| | |
| | | break |
| | | |
| | | for hyp in hyps: |
| | | logging.debug( |
| | | "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]) |
| | | ) |
| | | logging.debug("hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])) |
| | | |
| | | logging.debug("number of ended hypotheses: " + str(len(ended_hyps))) |
| | | |
| | | nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[ |
| | | : min(len(ended_hyps), recog_args.nbest) |
| | | ] |
| | | : min(len(ended_hyps), recog_args.nbest) |
| | | ] |
| | | |
| | | # check number of hypotheses |
| | | if len(nbest_hyps) == 0: |
| | | logging.warning( |
| | | "there is no N-best results, " |
| | | "perform recognition again with smaller minlenratio." |
| | | "there is no N-best results, " "perform recognition again with smaller minlenratio." |
| | | ) |
| | | # should copy because Namespace will be overwritten globally |
| | | recog_args = Namespace(**vars(recog_args)) |
| | |
| | | return nbest_hyps |
| | | |
| | | def recognize_beam_batch( |
| | | self, |
| | | h, |
| | | hlens, |
| | | lpz, |
| | | recog_args, |
| | | char_list, |
| | | rnnlm=None, |
| | | normalize_score=True, |
| | | strm_idx=0, |
| | | lang_ids=None, |
| | | self, |
| | | h, |
| | | hlens, |
| | | lpz, |
| | | recog_args, |
| | | char_list, |
| | | rnnlm=None, |
| | | normalize_score=True, |
| | | strm_idx=0, |
| | | lang_ids=None, |
| | | ): |
| | | # to support multiple encoder asr mode, in single encoder mode, |
| | | # convert torch.Tensor to List of torch.Tensor |
| | |
| | | weights_ctc_dec = recog_args.weights_ctc_dec / np.sum( |
| | | recog_args.weights_ctc_dec |
| | | ) # normalize |
| | | logging.info( |
| | | "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]) |
| | | ) |
| | | logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])) |
| | | else: |
| | | weights_ctc_dec = [1.0] |
| | | |
| | |
| | | logging.info("min output length: " + str(minlen)) |
| | | |
| | | # initialization |
| | | c_prev = [ |
| | | to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers) |
| | | ] |
| | | z_prev = [ |
| | | to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers) |
| | | ] |
| | | c_list = [ |
| | | to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers) |
| | | ] |
| | | z_list = [ |
| | | to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers) |
| | | ] |
| | | c_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)] |
| | | z_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)] |
| | | c_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)] |
| | | z_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)] |
| | | vscores = to_device(h[0], torch.zeros(batch, beam)) |
| | | |
| | | rnnlm_state = None |
| | |
| | | if self.replace_sos and recog_args.tgt_lang: |
| | | logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang))) |
| | | logging.info("<sos> mark: " + recog_args.tgt_lang) |
| | | yseq = [ |
| | | [char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb) |
| | | ] |
| | | yseq = [[char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)] |
| | | elif lang_ids is not None: |
| | | # NOTE: used for evaluation during training |
| | | yseq = [ |
| | | [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb) |
| | | ] |
| | | yseq = [[lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)] |
| | | else: |
| | | logging.info("<sos> index: " + str(self.sos)) |
| | | logging.info("<sos> mark: " + char_list[self.sos]) |
| | |
| | | ] |
| | | exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)] |
| | | exp_h = [ |
| | | h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous() |
| | | for idx in range(self.num_encs) |
| | | h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous() for idx in range(self.num_encs) |
| | | ] |
| | | exp_h = [ |
| | | exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2]) |
| | |
| | | |
| | | if lpz[0] is not None: |
| | | scoring_num = min( |
| | | int(beam * CTC_SCORING_RATIO) |
| | | if att_weight > 0.0 and not lpz[0].is_cuda |
| | | else 0, |
| | | int(beam * CTC_SCORING_RATIO) if att_weight > 0.0 and not lpz[0].is_cuda else 0, |
| | | lpz[0].size(-1), |
| | | ) |
| | | ctc_scorer = [ |
| | |
| | | # attention decoder |
| | | z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev) |
| | | if self.context_residual: |
| | | logits = self.output( |
| | | torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) |
| | | ) |
| | | logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)) |
| | | else: |
| | | logits = self.output(self.dropout_dec[-1](z_list[-1])) |
| | | local_scores = att_weight * F.log_softmax(logits, dim=1) |
| | |
| | | if ctc_scorer[0]: |
| | | local_scores[:, 0] = self.logzero # avoid choosing blank |
| | | part_ids = ( |
| | | torch.topk(local_scores, scoring_num, dim=-1)[1] |
| | | if scoring_num > 0 |
| | | else None |
| | | torch.topk(local_scores, scoring_num, dim=-1)[1] if scoring_num > 0 else None |
| | | ) |
| | | for idx in range(self.num_encs): |
| | | att_w = att_w_list[idx] |
| | |
| | | yseq, ctc_state[idx], part_ids, att_w_ |
| | | ) |
| | | local_scores = ( |
| | | local_scores |
| | | + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores |
| | | local_scores + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores |
| | | ) |
| | | |
| | | local_scores = local_scores.view(batch, beam, self.odim) |
| | |
| | | |
| | | # global pruning |
| | | accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1) |
| | | accum_odim_ids = ( |
| | | torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist() |
| | | ) |
| | | accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist() |
| | | accum_padded_beam_ids = ( |
| | | (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist() |
| | | ) |
| | |
| | | ] |
| | | else: |
| | | # handle the case of location_recurrent when return is a tuple |
| | | _a_prev_ = torch.index_select( |
| | | att_w_list[idx][0].view(n_bb, -1), 0, vidx |
| | | ) |
| | | _h_prev_ = torch.index_select( |
| | | att_w_list[idx][1][0].view(n_bb, -1), 0, vidx |
| | | ) |
| | | _c_prev_ = torch.index_select( |
| | | att_w_list[idx][1][1].view(n_bb, -1), 0, vidx |
| | | ) |
| | | _a_prev_ = torch.index_select(att_w_list[idx][0].view(n_bb, -1), 0, vidx) |
| | | _h_prev_ = torch.index_select(att_w_list[idx][1][0].view(n_bb, -1), 0, vidx) |
| | | _c_prev_ = torch.index_select(att_w_list[idx][1][1].view(n_bb, -1), 0, vidx) |
| | | _a_prev = (_a_prev_, (_h_prev_, _c_prev_)) |
| | | a_prev.append(_a_prev) |
| | | z_prev = [ |
| | | torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) |
| | | for li in range(self.dlayers) |
| | | torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers) |
| | | ] |
| | | c_prev = [ |
| | | torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) |
| | | for li in range(self.dlayers) |
| | | torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers) |
| | | ] |
| | | |
| | | # pick ended hyps |
| | |
| | | _vscore = None |
| | | if eos_vscores[samp_i, beam_j] > thr[samp_i]: |
| | | yk = y_prev[k][:] |
| | | if len(yk) <= min( |
| | | hlens[idx][samp_i] for idx in range(self.num_encs) |
| | | ): |
| | | if len(yk) <= min(hlens[idx][samp_i] for idx in range(self.num_encs)): |
| | | _vscore = eos_vscores[samp_i][beam_j] + penalty_i |
| | | elif i == maxlen - 1: |
| | | yk = yseq[k][:] |
| | |
| | | if _vscore: |
| | | yk.append(self.eos) |
| | | if rnnlm: |
| | | _vscore += recog_args.lm_weight * rnnlm.final( |
| | | rnnlm_state, index=k |
| | | ) |
| | | _vscore += recog_args.lm_weight * rnnlm.final(rnnlm_state, index=k) |
| | | _score = _vscore.data.cpu().numpy() |
| | | ended_hyps[samp_i].append( |
| | | {"yseq": yk, "vscore": _vscore, "score": _score} |
| | |
| | | |
| | | torch.cuda.empty_cache() |
| | | |
| | | dummy_hyps = [ |
| | | {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])} |
| | | ] |
| | | dummy_hyps = [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}] |
| | | ended_hyps = [ |
| | | ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps |
| | | for samp_i in six.moves.range(batch) |
| | |
| | | |
| | | nbest_hyps = [ |
| | | sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[ |
| | | : min(len(ended_hyps[samp_i]), recog_args.nbest) |
| | | : min(len(ended_hyps[samp_i]), recog_args.nbest) |
| | | ] |
| | | for samp_i in six.moves.range(batch) |
| | | ] |
| | |
| | | state["a_prev"][self.num_encs], |
| | | ) |
| | | ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim) |
| | | z_list, c_list = self.rnn_forward( |
| | | ey, z_list, c_list, state["z_prev"], state["c_prev"] |
| | | ) |
| | | z_list, c_list = self.rnn_forward(ey, z_list, c_list, state["z_prev"], state["c_prev"]) |
| | | if self.context_residual: |
| | | logits = self.output( |
| | | torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) |
| | | ) |
| | | logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)) |
| | | else: |
| | | logits = self.output(self.dropout_dec[-1](z_list[-1])) |
| | | logp = F.log_softmax(logits, dim=1).squeeze(0) |