| | |
| | | self.input_length = x.size(1) |
| | | self.odim = x.size(2) |
| | | self.dtype = x.dtype |
| | | self.device = ( |
| | | torch.device("cuda:%d" % x.get_device()) |
| | | if x.is_cuda |
| | | else torch.device("cpu") |
| | | ) |
| | | self.device = torch.device("cuda:%d" % x.get_device()) if x.is_cuda else torch.device("cpu") |
| | | # Pad the rest of posteriors in the batch |
| | | # TODO(takaaki-hori): need a better way without for-loops |
| | | for i, l in enumerate(xlens): |
| | |
| | | # Setup CTC windowing |
| | | self.margin = margin |
| | | if margin > 0: |
| | | self.frame_ids = torch.arange( |
| | | self.input_length, dtype=self.dtype, device=self.device |
| | | ) |
| | | self.frame_ids = torch.arange(self.input_length, dtype=self.dtype, device=self.device) |
| | | # Base indices for index conversion |
| | | self.idx_bh = None |
| | | self.idx_b = torch.arange(self.batch, device=self.device) |
| | |
| | | |
| | | # select input dimensions for scoring |
| | | if self.scoring_num > 0: |
| | | scoring_idmap = torch.full( |
| | | (n_bh, self.odim), -1, dtype=torch.long, device=self.device |
| | | ) |
| | | scoring_idmap = torch.full((n_bh, self.odim), -1, dtype=torch.long, device=self.device) |
| | | snum = self.scoring_num |
| | | if self.idx_bh is None or n_bh > len(self.idx_bh): |
| | | self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1) |
| | | scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange( |
| | | snum, device=self.device |
| | | ) |
| | | scoring_idx = ( |
| | | scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1) |
| | | ).view(-1) |
| | | scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(snum, device=self.device) |
| | | scoring_idx = (scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)).view(-1) |
| | | x_ = torch.index_select( |
| | | self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx |
| | | ).view(2, -1, n_bh, snum) |
| | |
| | | # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h)) |
| | | for t in range(start, end): |
| | | rp = r[t - 1] |
| | | rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view( |
| | | 2, 2, n_bh, snum |
| | | ) |
| | | rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, n_bh, snum) |
| | | r[t] = torch.logsumexp(rr, 1) + x_[:, t] |
| | | |
| | | # compute log prefix probabilities log(psi) |
| | |
| | | # convert ids to BHS space (S: scoring_num) |
| | | if scoring_idmap is not None: |
| | | snum = self.scoring_num |
| | | hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view( |
| | | -1 |
| | | ) |
| | | hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(-1) |
| | | label_ids = torch.fmod(best_ids, self.odim).view(-1) |
| | | score_idx = scoring_idmap[hyp_idx, label_ids] |
| | | score_idx[score_idx == -1] = 0 |
| | |
| | | else: |
| | | snum = self.odim |
| | | # select forward probabilities |
| | | r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view( |
| | | -1, 2, n_bh |
| | | ) |
| | | r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(-1, 2, n_bh) |
| | | return r_new, s_new, f_min, f_max |
| | | |
| | | def extend_prob(self, x): |
| | |
| | | r[output_length - 1] = self.logzero |
| | | |
| | | # prepare forward probabilities for the last label |
| | | r_sum = self.xp.logaddexp( |
| | | r_prev[:, 0], r_prev[:, 1] |
| | | ) # log(r_t^n(g) + r_t^b(g)) |
| | | r_sum = self.xp.logaddexp(r_prev[:, 0], r_prev[:, 1]) # log(r_t^n(g) + r_t^b(g)) |
| | | last = y[-1] |
| | | if output_length > 0 and last in cs: |
| | | log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32) |
| | |
| | | log_psi = r[start - 1, 0] |
| | | for t in six.moves.range(start, self.input_length): |
| | | r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t] |
| | | r[t, 1] = ( |
| | | self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank] |
| | | ) |
| | | r[t, 1] = self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank] |
| | | log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t]) |
| | | |
| | | # get P(...eos|X) that ends with the prefix itself |