From 37d7764ecf0e8cc1a14f59b8b9cd1c914da8b005 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Sun, 21 Jan 2024 21:06:52 +0800
Subject: [PATCH] Funasr1.0 (#1277)

---
 funasr/models/scama/model.py |   97 ++++++++++++++++++++++++------------------------
 1 file changed, 48 insertions(+), 49 deletions(-)

diff --git a/funasr/models/scama/model.py b/funasr/models/scama/model.py
index aec6fe3..32e16bd 100644
--- a/funasr/models/scama/model.py
+++ b/funasr/models/scama/model.py
@@ -436,7 +436,10 @@
     def init_beam_search(self,
                          **kwargs,
                          ):
-        from funasr.models.scama.beam_search import BeamSearchScama
+
+        from funasr.models.scama.beam_search import BeamSearchScamaStreaming
+
+
         from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
         from funasr.models.transformer.scorers.length_bonus import LengthBonus
     
@@ -460,13 +463,14 @@
         scorers["ngram"] = ngram
     
         weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
             ctc=kwargs.get("decoding_ctc_weight", 0.0),
             lm=kwargs.get("lm_weight", 0.0),
             ngram=kwargs.get("ngram_weight", 0.0),
             length_bonus=kwargs.get("penalty", 0.0),
         )
-        beam_search = BeamSearchScama(
+        
+        beam_search = BeamSearchScamaStreaming(
             beam_size=kwargs.get("beam_size", 2),
             weights=weights,
             scorers=scorers,
@@ -499,7 +503,11 @@
                                                           is_final=kwargs.get("is_final", False))
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
-
+        if "running_hyps" not in cache:
+            running_hyps = self.beam_search.init_hyp(encoder_out)
+            cache["running_hyps"] = running_hyps
+       
+       
         # predictor
         predictor_outs = self.calc_predictor_chunk(encoder_out,
                                                    encoder_out_lens,
@@ -513,47 +521,30 @@
 
         if torch.max(pre_token_length) < 1:
             return []
-        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
-                                                             encoder_out_lens,
-                                                             pre_acoustic_embeds,
-                                                             pre_token_length,
-                                                             cache=cache
-                                                             )
-        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-    
+        maxlen = minlen = pre_token_length
+        if kwargs.get("is_final", False):
+            maxlen += kwargs.get("token_num_relax", 5)
+            minlen = max(0, minlen - kwargs.get("token_num_relax", 5))
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], scama_mask=None, pre_acoustic_embeds=pre_acoustic_embeds, maxlen=int(maxlen), minlen=int(minlen), cache=cache,
+        )
+
+        cache["running_hyps"] = nbest_hyps
+        nbest_hyps = nbest_hyps[: self.nbest]
+
         results = []
-        b, n, d = decoder_out.size()
-        if isinstance(key[0], (list, tuple)):
-            key = key[0]
-        for i in range(b):
-            x = encoder_out[i, :encoder_out_lens[i], :]
-            am_scores = decoder_out[i, :pre_token_length[i], :]
-            if self.beam_search is not None:
-                nbest_hyps = self.beam_search(
-                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
-                    minlenratio=kwargs.get("minlenratio", 0.0)
-                )
-            
-                nbest_hyps = nbest_hyps[: self.nbest]
+        for hyp in nbest_hyps:
+            # assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
             else:
-            
-                yseq = am_scores.argmax(dim=-1)
-                score = am_scores.max(dim=-1)[0]
-                score = torch.sum(score, dim=-1)
-                # pad with mask tokens to ensure compatibility with sos/eos tokens
-                yseq = torch.tensor(
-                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
-                )
-                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-            for nbest_idx, hyp in enumerate(nbest_hyps):
-            
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
-            
+                token_int = hyp.yseq[1:last_pos].tolist()
+
+
                 # remove blank symbol id, which is assumed to be 0
                 token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
             
@@ -568,6 +559,8 @@
         return results
 
     def init_cache(self, cache: dict = {}, **kwargs):
+        device = kwargs.get("device", "cuda")
+        
         chunk_size = kwargs.get("chunk_size", [0, 10, 5])
         encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
         decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
@@ -575,10 +568,11 @@
     
         enc_output_size = kwargs["encoder_conf"]["output_size"]
         feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
-        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
-                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+
+        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
+                         "cif_alphas": torch.zeros((batch_size, 1)).to(device=device), "chunk_size": chunk_size,
                          "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
-                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device),
                          "tail_chunk": False}
         cache["encoder"] = cache_encoder
     
@@ -586,8 +580,10 @@
                          "chunk_size": chunk_size}
         cache["decoder"] = cache_decoder
         cache["frontend"] = {}
-        cache["prev_samples"] = torch.empty(0)
-    
+
+
+        cache["prev_samples"] = torch.empty(0).to(device=device)
+
         return cache
 
     def inference(self,
@@ -603,7 +599,10 @@
         # init beamsearch
         is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
         is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
+
+        if self.beam_search is None:
+
+
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)

--
Gitblit v1.9.1