From b75d1e89bb2f513a79bb07e9100ba1cd2bbcf40c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 09 六月 2024 00:32:57 +0800
Subject: [PATCH] fix bug

---
 funasr/models/transformer/scorers/ctc_prefix_score.py |   42 ++++++++++--------------------------------
 1 files changed, 10 insertions(+), 32 deletions(-)

diff --git a/funasr/models/transformer/scorers/ctc_prefix_score.py b/funasr/models/transformer/scorers/ctc_prefix_score.py
index 0c67ecd..ce0703e 100644
--- a/funasr/models/transformer/scorers/ctc_prefix_score.py
+++ b/funasr/models/transformer/scorers/ctc_prefix_score.py
@@ -38,11 +38,7 @@
         self.input_length = x.size(1)
         self.odim = x.size(2)
         self.dtype = x.dtype
-        self.device = (
-            torch.device("cuda:%d" % x.get_device())
-            if x.is_cuda
-            else torch.device("cpu")
-        )
+        self.device = torch.device("cuda:%d" % x.get_device()) if x.is_cuda else torch.device("cpu")
         # Pad the rest of posteriors in the batch
         # TODO(takaaki-hori): need a better way without for-loops
         for i, l in enumerate(xlens):
@@ -58,9 +54,7 @@
         # Setup CTC windowing
         self.margin = margin
         if margin > 0:
-            self.frame_ids = torch.arange(
-                self.input_length, dtype=self.dtype, device=self.device
-            )
+            self.frame_ids = torch.arange(self.input_length, dtype=self.dtype, device=self.device)
         # Base indices for index conversion
         self.idx_bh = None
         self.idx_b = torch.arange(self.batch, device=self.device)
@@ -98,18 +92,12 @@
 
         # select input dimensions for scoring
         if self.scoring_num > 0:
-            scoring_idmap = torch.full(
-                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
-            )
+            scoring_idmap = torch.full((n_bh, self.odim), -1, dtype=torch.long, device=self.device)
             snum = self.scoring_num
             if self.idx_bh is None or n_bh > len(self.idx_bh):
                 self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
-            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
-                snum, device=self.device
-            )
-            scoring_idx = (
-                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
-            ).view(-1)
+            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(snum, device=self.device)
+            scoring_idx = (scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)).view(-1)
             x_ = torch.index_select(
                 self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
             ).view(2, -1, n_bh, snum)
@@ -156,9 +144,7 @@
         # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
         for t in range(start, end):
             rp = r[t - 1]
-            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
-                2, 2, n_bh, snum
-            )
+            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, n_bh, snum)
             r[t] = torch.logsumexp(rr, 1) + x_[:, t]
 
         # compute log prefix probabilities log(psi)
@@ -205,9 +191,7 @@
         # convert ids to BHS space (S: scoring_num)
         if scoring_idmap is not None:
             snum = self.scoring_num
-            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
-                -1
-            )
+            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(-1)
             label_ids = torch.fmod(best_ids, self.odim).view(-1)
             score_idx = scoring_idmap[hyp_idx, label_ids]
             score_idx[score_idx == -1] = 0
@@ -215,9 +199,7 @@
         else:
             snum = self.odim
         # select forward probabilities
-        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
-            -1, 2, n_bh
-        )
+        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(-1, 2, n_bh)
         return r_new, s_new, f_min, f_max
 
     def extend_prob(self, x):
@@ -322,9 +304,7 @@
             r[output_length - 1] = self.logzero
 
         # prepare forward probabilities for the last label
-        r_sum = self.xp.logaddexp(
-            r_prev[:, 0], r_prev[:, 1]
-        )  # log(r_t^n(g) + r_t^b(g))
+        r_sum = self.xp.logaddexp(r_prev[:, 0], r_prev[:, 1])  # log(r_t^n(g) + r_t^b(g))
         last = y[-1]
         if output_length > 0 and last in cs:
             log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
@@ -339,9 +319,7 @@
         log_psi = r[start - 1, 0]
         for t in six.moves.range(start, self.input_length):
             r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
-            r[t, 1] = (
-                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
-            )
+            r[t, 1] = self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
             log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
 
         # get P(...eos|X) that ends with the prefix itself

--
Gitblit v1.9.1