From 33a9e08dc9b65abc3f3e18d14253f95c79e0f749 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 08 May 2024 19:20:43 +0800
Subject: [PATCH] dynamic batch

---
 funasr/models/scama/beam_search.py |  121 +++++++++++++++++++++++-----------------
 1 file changed, 69 insertions(+), 52 deletions(-)

diff --git a/funasr/models/scama/beam_search.py b/funasr/models/scama/beam_search.py
index b8aa876..259f483 100644
--- a/funasr/models/scama/beam_search.py
+++ b/funasr/models/scama/beam_search.py
@@ -33,7 +33,6 @@
         )._asdict()
 
 
-
 class BeamSearchScama(torch.nn.Module):
     """Beam search implementation."""
 
@@ -151,7 +150,8 @@
         return torch.cat((xs, x))
 
     def score_full(
-        self, hyp: Hypothesis,
+        self,
+        hyp: Hypothesis,
         x: torch.Tensor,
         x_mask: torch.Tensor = None,
         pre_acoustic_embeds: torch.Tensor = None,
@@ -173,7 +173,9 @@
         scores = dict()
         states = dict()
         for k, d in self.full_scorers.items():
-            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
+            scores[k], states[k] = d.score(
+                hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds
+            )
         return scores, states
 
     def score_partial(
@@ -283,7 +285,8 @@
         return new_states
 
     def search(
-        self, running_hyps: List[Hypothesis],
+        self,
+        running_hyps: List[Hypothesis],
         x: torch.Tensor,
         x_mask: torch.Tensor = None,
         pre_acoustic_embeds: torch.Tensor = None,
@@ -303,7 +306,9 @@
         for hyp in running_hyps:
             # scoring
             weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
-            scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
+            scores, states = self.score_full(
+                hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds
+            )
             for k in self.full_scorers:
                 weighted_scores += self.weights[k] * scores[k]
             # partial scoring
@@ -327,9 +332,7 @@
                     Hypothesis(
                         score=weighted_scores[j],
                         yseq=self.append_token(hyp.yseq, j),
-                        scores=self.merge_scores(
-                            hyp.scores, scores, j, part_scores, part_j
-                        ),
+                        scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
                         states=self.merge_states(states, part_states, part_j),
                     )
                 )
@@ -341,7 +344,8 @@
         return best_hyps
 
     def forward(
-        self, x: torch.Tensor,
+        self,
+        x: torch.Tensor,
         scama_mask: torch.Tensor = None,
         pre_acoustic_embeds: torch.Tensor = None,
         maxlenratio: float = 0.0,
@@ -386,20 +390,26 @@
             mask_enc = None
             if scama_mask is not None:
                 token_num_predictor = scama_mask.size(1)
-                token_id_slice = min(i, token_num_predictor-1)
-                mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
+                token_id_slice = min(i, token_num_predictor - 1)
+                mask_enc = scama_mask[:, token_id_slice : token_id_slice + 1, :]
                 # if mask_enc.size(1) == 0:
                 #     mask_enc = scama_mask[:, -2:-1, :]
                 #     # mask_enc = torch.zeros_like(mask_enc)
             pre_acoustic_embeds_cur = None
             if pre_acoustic_embeds is not None:
                 b, t, d = pre_acoustic_embeds.size()
-                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(
+                    device=pre_acoustic_embeds.device
+                )
                 pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
                 token_id_slice = min(i, t)
-                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
+                pre_acoustic_embeds_cur = pre_acoustic_embeds[
+                    :, token_id_slice : token_id_slice + 1, :
+                ]
 
-            best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur)
+            best = self.search(
+                running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur
+            )
             # post process of one iteration
             running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
             # end detection
@@ -416,8 +426,7 @@
         # check the number of hypotheses reaching to eos
         if len(nbest_hyps) == 0:
             logging.warning(
-                "there is no N-best results, perform recognition "
-                "again with smaller minlenratio."
+                "there is no N-best results, perform recognition " "again with smaller minlenratio."
             )
             return (
                 []
@@ -431,17 +440,13 @@
             logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
         best = nbest_hyps[0]
         for k, v in best.scores.items():
-            logging.info(
-                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
-            )
+            logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
         logging.info(f"total log probability: {best.score:.2f}")
         logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
         logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
         if self.token_list is not None:
             logging.info(
-                "best hypo: "
-                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
-                + "\n"
+                "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
             )
         return nbest_hyps
 
@@ -469,15 +474,13 @@
         logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
         if self.token_list is not None:
             logging.debug(
-                "best hypo: "
-                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+                "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
             )
         # add eos in the final loop to avoid that there are no ended hyps
         if i == maxlen - 1:
             logging.info("adding <eos> in the last position in the loop")
             running_hyps = [
-                h._replace(yseq=self.append_token(h.yseq, self.eos))
-                for h in running_hyps
+                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
             ]
 
         # add ended hypotheses to a final list, and removed them from current hypotheses
@@ -494,6 +497,7 @@
             else:
                 remained_hyps.append(hyp)
         return remained_hyps
+
 
 class BeamSearchScamaStreaming(torch.nn.Module):
     """Beam search implementation."""
@@ -612,11 +616,12 @@
         return torch.cat((xs, x))
 
     def score_full(
-        self, hyp: Hypothesis,
+        self,
+        hyp: Hypothesis,
         x: torch.Tensor,
         x_mask: torch.Tensor = None,
         pre_acoustic_embeds: torch.Tensor = None,
-        cache: dict={},
+        cache: dict = {},
     ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
         """Score new hypothesis by `self.full_scorers`.
 
@@ -635,7 +640,14 @@
         scores = dict()
         states = dict()
         for k, d in self.full_scorers.items():
-            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+            scores[k], states[k] = d.score(
+                hyp.yseq,
+                hyp.states[k],
+                x,
+                x_mask=x_mask,
+                pre_acoustic_embeds=pre_acoustic_embeds,
+                cache=cache,
+            )
         return scores, states
 
     def score_partial(
@@ -745,11 +757,12 @@
         return new_states
 
     def search(
-        self, running_hyps: List[Hypothesis],
+        self,
+        running_hyps: List[Hypothesis],
         x: torch.Tensor,
         x_mask: torch.Tensor = None,
         pre_acoustic_embeds: torch.Tensor = None,
-        cache: dict={},
+        cache: dict = {},
     ) -> List[Hypothesis]:
         """Search new tokens for running hypotheses and encoded speech x.
 
@@ -766,7 +779,9 @@
         for hyp in running_hyps:
             # scoring
             weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
-            scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+            scores, states = self.score_full(
+                hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache
+            )
             for k in self.full_scorers:
                 weighted_scores += self.weights[k] * scores[k]
             # partial scoring
@@ -790,9 +805,7 @@
                     Hypothesis(
                         score=weighted_scores[j],
                         yseq=self.append_token(hyp.yseq, j),
-                        scores=self.merge_scores(
-                            hyp.scores, scores, j, part_scores, part_j
-                        ),
+                        scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
                         states=self.merge_states(states, part_states, part_j),
                     )
                 )
@@ -804,14 +817,15 @@
         return best_hyps
 
     def forward(
-        self, x: torch.Tensor,
+        self,
+        x: torch.Tensor,
         scama_mask: torch.Tensor = None,
         pre_acoustic_embeds: torch.Tensor = None,
         maxlenratio: float = 0.0,
         minlenratio: float = 0.0,
         maxlen: int = None,
         minlen: int = 0,
-        cache:dict={},
+        cache: dict = {},
     ) -> List[Hypothesis]:
         """Perform beam search.
 
@@ -859,12 +873,22 @@
             pre_acoustic_embeds_cur = None
             if pre_acoustic_embeds is not None:
                 b, t, d = pre_acoustic_embeds.size()
-                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(
+                    device=pre_acoustic_embeds.device
+                )
                 pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
                 token_id_slice = min(i, t)
-                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
+                pre_acoustic_embeds_cur = pre_acoustic_embeds[
+                    :, token_id_slice : token_id_slice + 1, :
+                ]
 
-            best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur, cache=cache["decoder"])
+            best = self.search(
+                running_hyps,
+                x,
+                x_mask=mask_enc,
+                pre_acoustic_embeds=pre_acoustic_embeds_cur,
+                cache=cache["decoder"],
+            )
             # post process of one iteration
             running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
             # end detection
@@ -881,8 +905,7 @@
         # check the number of hypotheses reaching to eos
         if len(nbest_hyps) == 0:
             logging.warning(
-                "there is no N-best results, perform recognition "
-                "again with smaller minlenratio."
+                "there is no N-best results, perform recognition " "again with smaller minlenratio."
             )
             return (
                 []
@@ -896,17 +919,13 @@
             logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
         best = nbest_hyps[0]
         for k, v in best.scores.items():
-            logging.info(
-                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
-            )
+            logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
         logging.info(f"total log probability: {best.score:.2f}")
         logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
         logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
         if self.token_list is not None:
             logging.info(
-                "best hypo: "
-                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
-                + "\n"
+                "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
             )
         return nbest_hyps
 
@@ -934,15 +953,13 @@
         logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
         if self.token_list is not None:
             logging.debug(
-                "best hypo: "
-                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+                "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
             )
         # add eos in the final loop to avoid that there are no ended hyps
         if i == maxlen - 1:
             logging.info("adding <eos> in the last position in the loop")
             running_hyps = [
-                h._replace(yseq=self.append_token(h.yseq, self.eos))
-                for h in running_hyps
+                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
             ]
 
         # add ended hypotheses to a final list, and removed them from current hypotheses

--
Gitblit v1.9.1