From 37d7764ecf0e8cc1a14f59b8b9cd1c914da8b005 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Sun, 21 Jan 2024 21:06:52 +0800
Subject: [PATCH] Funasr1.0 (#1277)
---
funasr/models/scama/model.py | 97 ++++++++++++++++++++++++------------------------
 1 file changed, 48 insertions(+), 49 deletions(-)
diff --git a/funasr/models/scama/model.py b/funasr/models/scama/model.py
index aec6fe3..32e16bd 100644
--- a/funasr/models/scama/model.py
+++ b/funasr/models/scama/model.py
@@ -436,7 +436,10 @@
def init_beam_search(self,
**kwargs,
):
-        from funasr.models.scama.beam_search import BeamSearchScama
+        from funasr.models.scama.beam_search import BeamSearchScamaStreaming
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
@@ -460,13 +463,14 @@
scorers["ngram"] = ngram
weights = dict(
- decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+ decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
ctc=kwargs.get("decoding_ctc_weight", 0.0),
lm=kwargs.get("lm_weight", 0.0),
ngram=kwargs.get("ngram_weight", 0.0),
length_bonus=kwargs.get("penalty", 0.0),
)
-        beam_search = BeamSearchScama(
+        beam_search = BeamSearchScamaStreaming(
beam_size=kwargs.get("beam_size", 2),
weights=weights,
scorers=scorers,
@@ -499,7 +503,11 @@
is_final=kwargs.get("is_final", False))
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
-
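+        # beam-search state lives in the cache so decoding resumes across streaming chunks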
+ if "running_hyps" not in cache:
+ running_hyps = self.beam_search.init_hyp(encoder_out)
+ cache["running_hyps"] = running_hyps
+
+
# predictor
predictor_outs = self.calc_predictor_chunk(encoder_out,
encoder_out_lens,
@@ -513,47 +521,30 @@
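+        # nothing to decode if the CIF predictor emits no tokens for this chunk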
if torch.max(pre_token_length) < 1:
return []
- decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
- encoder_out_lens,
- pre_acoustic_embeds,
- pre_token_length,
- cache=cache
- )
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
+        # widen the decoding length bounds on the final chunk
+        maxlen = pre_token_length
+        minlen = pre_token_length
+        if kwargs.get("is_final", False):
+            # avoid in-place "+=" here: maxlen, minlen and pre_token_length alias the same tensor
+            maxlen = maxlen + kwargs.get("token_num_relax", 5)
+            minlen = torch.clamp(minlen - kwargs.get("token_num_relax", 5), min=0)
+        # c. pass the encoder output to the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], scama_mask=None, pre_acoustic_embeds=pre_acoustic_embeds,
+            maxlen=int(maxlen), minlen=int(minlen), cache=cache,
+        )
+
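+        # persist the updated hypotheses so the next chunk continues from them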
+ cache["running_hyps"] = nbest_hyps
+ nbest_hyps = nbest_hyps[: self.nbest]
+
results = []
- b, n, d = decoder_out.size()
- if isinstance(key[0], (list, tuple)):
- key = key[0]
- for i in range(b):
- x = encoder_out[i, :encoder_out_lens[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
- minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
+ for hyp in nbest_hyps:
+ # assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
else:
-
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- for nbest_idx, hyp in enumerate(nbest_hyps):
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
+                token_int = hyp.yseq[1:last_pos].tolist()
+
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
@@ -568,6 +559,8 @@
return results
def init_cache(self, cache: dict = {}, **kwargs):
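+        # NOTE: defaulting to "cuda" assumes a GPU is available; pass device="cpu" otherwise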
+ device = kwargs.get("device", "cuda")
+
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
@@ -575,10 +568,11 @@
enc_output_size = kwargs["encoder_conf"]["output_size"]
feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
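+        # allocate persistent streaming tensors directly on the target device to avoid per-chunk copies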
-        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
-                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
+                         "cif_alphas": torch.zeros((batch_size, 1)).to(device=device), "chunk_size": chunk_size,
"encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device),
"tail_chunk": False}
cache["encoder"] = cache_encoder
@@ -586,8 +580,10 @@
"chunk_size": chunk_size}
cache["decoder"] = cache_decoder
cache["frontend"] = {}
- cache["prev_samples"] = torch.empty(0)
-
+
+
+ cache["prev_samples"] = torch.empty(0).to(device=device)
+
return cache
def inference(self,
@@ -603,7 +599,10 @@
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
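+        # streaming decoding always uses the scama beam search, even without CTC or LM rescoring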
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
+        if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
--
Gitblit v1.9.1