From 33a9e08dc9b65abc3f3e18d14253f95c79e0f749 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 08 五月 2024 19:20:43 +0800
Subject: [PATCH] dynamic batch
---
funasr/models/scama/beam_search.py | 121 +++++++++++++++++++++++-----------------
1 files changed, 69 insertions(+), 52 deletions(-)
diff --git a/funasr/models/scama/beam_search.py b/funasr/models/scama/beam_search.py
index b8aa876..259f483 100644
--- a/funasr/models/scama/beam_search.py
+++ b/funasr/models/scama/beam_search.py
@@ -33,7 +33,6 @@
)._asdict()
-
class BeamSearchScama(torch.nn.Module):
"""Beam search implementation."""
@@ -151,7 +150,8 @@
return torch.cat((xs, x))
def score_full(
- self, hyp: Hypothesis,
+ self,
+ hyp: Hypothesis,
x: torch.Tensor,
x_mask: torch.Tensor = None,
pre_acoustic_embeds: torch.Tensor = None,
@@ -173,7 +173,9 @@
scores = dict()
states = dict()
for k, d in self.full_scorers.items():
- scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
+ scores[k], states[k] = d.score(
+ hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds
+ )
return scores, states
def score_partial(
@@ -283,7 +285,8 @@
return new_states
def search(
- self, running_hyps: List[Hypothesis],
+ self,
+ running_hyps: List[Hypothesis],
x: torch.Tensor,
x_mask: torch.Tensor = None,
pre_acoustic_embeds: torch.Tensor = None,
@@ -303,7 +306,9 @@
for hyp in running_hyps:
# scoring
weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
- scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
+ scores, states = self.score_full(
+ hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds
+ )
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
# partial scoring
@@ -327,9 +332,7 @@
Hypothesis(
score=weighted_scores[j],
yseq=self.append_token(hyp.yseq, j),
- scores=self.merge_scores(
- hyp.scores, scores, j, part_scores, part_j
- ),
+ scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
states=self.merge_states(states, part_states, part_j),
)
)
@@ -341,7 +344,8 @@
return best_hyps
def forward(
- self, x: torch.Tensor,
+ self,
+ x: torch.Tensor,
scama_mask: torch.Tensor = None,
pre_acoustic_embeds: torch.Tensor = None,
maxlenratio: float = 0.0,
@@ -386,20 +390,26 @@
mask_enc = None
if scama_mask is not None:
token_num_predictor = scama_mask.size(1)
- token_id_slice = min(i, token_num_predictor-1)
- mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
+ token_id_slice = min(i, token_num_predictor - 1)
+ mask_enc = scama_mask[:, token_id_slice : token_id_slice + 1, :]
# if mask_enc.size(1) == 0:
# mask_enc = scama_mask[:, -2:-1, :]
# # mask_enc = torch.zeros_like(mask_enc)
pre_acoustic_embeds_cur = None
if pre_acoustic_embeds is not None:
b, t, d = pre_acoustic_embeds.size()
- pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+ pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(
+ device=pre_acoustic_embeds.device
+ )
pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
token_id_slice = min(i, t)
- pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
+ pre_acoustic_embeds_cur = pre_acoustic_embeds[
+ :, token_id_slice : token_id_slice + 1, :
+ ]
- best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur)
+ best = self.search(
+ running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur
+ )
# post process of one iteration
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
# end detection
@@ -416,8 +426,7 @@
# check the number of hypotheses reaching to eos
if len(nbest_hyps) == 0:
logging.warning(
- "there is no N-best results, perform recognition "
- "again with smaller minlenratio."
+ "there is no N-best results, perform recognition " "again with smaller minlenratio."
)
return (
[]
@@ -431,17 +440,13 @@
logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
best = nbest_hyps[0]
for k, v in best.scores.items():
- logging.info(
- f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
- )
+ logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
logging.info(f"total log probability: {best.score:.2f}")
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
if self.token_list is not None:
logging.info(
- "best hypo: "
- + "".join([self.token_list[x] for x in best.yseq[1:-1]])
- + "\n"
+ "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
)
return nbest_hyps
@@ -469,15 +474,13 @@
logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
if self.token_list is not None:
logging.debug(
- "best hypo: "
- + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+ "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
)
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
running_hyps = [
- h._replace(yseq=self.append_token(h.yseq, self.eos))
- for h in running_hyps
+ h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
]
# add ended hypotheses to a final list, and removed them from current hypotheses
@@ -494,6 +497,7 @@
else:
remained_hyps.append(hyp)
return remained_hyps
+
class BeamSearchScamaStreaming(torch.nn.Module):
"""Beam search implementation."""
@@ -612,11 +616,12 @@
return torch.cat((xs, x))
def score_full(
- self, hyp: Hypothesis,
+ self,
+ hyp: Hypothesis,
x: torch.Tensor,
x_mask: torch.Tensor = None,
pre_acoustic_embeds: torch.Tensor = None,
- cache: dict={},
+ cache: dict = {},
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`.
@@ -635,7 +640,14 @@
scores = dict()
states = dict()
for k, d in self.full_scorers.items():
- scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+ scores[k], states[k] = d.score(
+ hyp.yseq,
+ hyp.states[k],
+ x,
+ x_mask=x_mask,
+ pre_acoustic_embeds=pre_acoustic_embeds,
+ cache=cache,
+ )
return scores, states
def score_partial(
@@ -745,11 +757,12 @@
return new_states
def search(
- self, running_hyps: List[Hypothesis],
+ self,
+ running_hyps: List[Hypothesis],
x: torch.Tensor,
x_mask: torch.Tensor = None,
pre_acoustic_embeds: torch.Tensor = None,
- cache: dict={},
+ cache: dict = {},
) -> List[Hypothesis]:
"""Search new tokens for running hypotheses and encoded speech x.
@@ -766,7 +779,9 @@
for hyp in running_hyps:
# scoring
weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
- scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache)
+ scores, states = self.score_full(
+ hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache
+ )
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
# partial scoring
@@ -790,9 +805,7 @@
Hypothesis(
score=weighted_scores[j],
yseq=self.append_token(hyp.yseq, j),
- scores=self.merge_scores(
- hyp.scores, scores, j, part_scores, part_j
- ),
+ scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
states=self.merge_states(states, part_states, part_j),
)
)
@@ -804,14 +817,15 @@
return best_hyps
def forward(
- self, x: torch.Tensor,
+ self,
+ x: torch.Tensor,
scama_mask: torch.Tensor = None,
pre_acoustic_embeds: torch.Tensor = None,
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
maxlen: int = None,
minlen: int = 0,
- cache:dict={},
+ cache: dict = {},
) -> List[Hypothesis]:
"""Perform beam search.
@@ -859,12 +873,22 @@
pre_acoustic_embeds_cur = None
if pre_acoustic_embeds is not None:
b, t, d = pre_acoustic_embeds.size()
- pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+ pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(
+ device=pre_acoustic_embeds.device
+ )
pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
token_id_slice = min(i, t)
- pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
+ pre_acoustic_embeds_cur = pre_acoustic_embeds[
+ :, token_id_slice : token_id_slice + 1, :
+ ]
- best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur, cache=cache["decoder"])
+ best = self.search(
+ running_hyps,
+ x,
+ x_mask=mask_enc,
+ pre_acoustic_embeds=pre_acoustic_embeds_cur,
+ cache=cache["decoder"],
+ )
# post process of one iteration
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
# end detection
@@ -881,8 +905,7 @@
# check the number of hypotheses reaching to eos
if len(nbest_hyps) == 0:
logging.warning(
- "there is no N-best results, perform recognition "
- "again with smaller minlenratio."
+ "there is no N-best results, perform recognition " "again with smaller minlenratio."
)
return (
[]
@@ -896,17 +919,13 @@
logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
best = nbest_hyps[0]
for k, v in best.scores.items():
- logging.info(
- f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
- )
+ logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
logging.info(f"total log probability: {best.score:.2f}")
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
if self.token_list is not None:
logging.info(
- "best hypo: "
- + "".join([self.token_list[x] for x in best.yseq[1:-1]])
- + "\n"
+ "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
)
return nbest_hyps
@@ -934,15 +953,13 @@
logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
if self.token_list is not None:
logging.debug(
- "best hypo: "
- + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+ "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
)
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
running_hyps = [
- h._replace(yseq=self.append_token(h.yseq, self.eos))
- for h in running_hyps
+ h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
]
# add ended hypotheses to a final list, and removed them from current hypotheses
--
Gitblit v1.9.1