zhifu gao
2024-01-21 37d7764ecf0e8cc1a14f59b8b9cd1c914da8b005
funasr/models/scama/model.py
@@ -436,7 +436,10 @@
    def init_beam_search(self,
                         **kwargs,
                         ):
        from funasr.models.scama.beam_search import BeamSearchScama
+        from funasr.models.scama.beam_search import BeamSearchScamaStreaming
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
    
@@ -460,13 +463,14 @@
        scorers["ngram"] = ngram
    
        weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
-        beam_search = BeamSearchScama(
+        beam_search = BeamSearchScamaStreaming(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
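
Note on the first change in this hunk: `dict.get` without a default returns `None` for a missing key, so the old `1.0 - kwargs.get("decoding_ctc_weight")` raised a `TypeError` whenever the caller did not pass `decoding_ctc_weight`. A minimal reproduction of the failure mode and the fix:

    kwargs = {}  # caller did not supply "decoding_ctc_weight"
    # Old form: 1.0 - kwargs.get("decoding_ctc_weight") -> 1.0 - None
    # TypeError: unsupported operand type(s) for -: 'float' and 'NoneType'
    decoder_weight = 1.0 - kwargs.get("decoding_ctc_weight", 0.0)  # fixed: 1.0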
@@ -499,7 +503,11 @@
                                                          is_final=kwargs.get("is_final", False))
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        if "running_hyps" not in cache:
            running_hyps = self.beam_search.init_hyp(encoder_out)
            cache["running_hyps"] = running_hyps
        # predictor
        predictor_outs = self.calc_predictor_chunk(encoder_out,
                                                   encoder_out_lens,
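
The `running_hyps` entry added above is the streaming beam-search state that survives across chunks: hypotheses are seeded once from the first chunk's encoder output and thereafter read from and written back to `cache`. A simplified sketch of the pattern (illustrative names, not the FunASR API):

    cache = {}
    for chunk_idx, encoder_out in enumerate(encoder_chunks):
        if "running_hyps" not in cache:                     # first chunk only
            cache["running_hyps"] = beam_search.init_hyp(encoder_out)
        nbest_hyps = beam_search(encoder_out, cache=cache)  # extends prior hypotheses
        cache["running_hyps"] = nbest_hyps                  # carry over to next chunk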
@@ -513,47 +521,30 @@
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
                                                             encoder_out_lens,
                                                             pre_acoustic_embeds,
                                                             pre_token_length,
                                                             cache=cache
                                                             )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+        maxlen = minlen = pre_token_length
+        if kwargs.get("is_final", False):
+            maxlen += kwargs.get("token_num_relax", 5)
+            minlen = max(0, minlen - kwargs.get("token_num_relax", 5))
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], scama_mask=None, pre_acoustic_embeds=pre_acoustic_embeds, maxlen=int(maxlen), minlen=int(minlen), cache=cache,
+        )
+        cache["running_hyps"] = nbest_hyps
+        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
-        for i in range(b):
-            x = encoder_out[i, :encoder_out_lens[i], :]
-            am_scores = decoder_out[i, :pre_token_length[i], :]
-            if self.beam_search is not None:
-                nbest_hyps = self.beam_search(
-                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
-                    minlenratio=kwargs.get("minlenratio", 0.0)
-                )
-                nbest_hyps = nbest_hyps[: self.nbest]
+        for hyp in nbest_hyps:
+            # assert isinstance(hyp, (Hypothesis)), type(hyp)
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
            else:
-                yseq = am_scores.argmax(dim=-1)
-                score = am_scores.max(dim=-1)[0]
-                score = torch.sum(score, dim=-1)
-                # pad with mask tokens to ensure compatibility with sos/eos tokens
-                yseq = torch.tensor(
-                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
-                )
-                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-            for nbest_idx, hyp in enumerate(nbest_hyps):
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
+                token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
            
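Two details in the new decoding path are easy to miss. On a final chunk the CIF token-count estimate is relaxed in both directions: with `pre_token_length = 4` and the default `token_num_relax = 5`, decoding runs with `maxlen = 9` and `minlen = max(0, 4 - 5) = 0`. And each kept hypothesis is stripped of special symbols before text conversion; a small self-contained sketch of that clean-up (placeholder token ids):

    sos, eos, blank_id = 1, 2, 0                 # hypothetical ids for illustration
    yseq = [1, 7, 0, 9, 2]                       # <sos> ... <eos>
    token_int = yseq[1:-1]                       # drop <sos> and the last position
    token_int = [t for t in token_int if t not in (sos, eos, blank_id)]
    assert token_int == [7, 9]
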
@@ -568,6 +559,8 @@
        return results
    def init_cache(self, cache: dict = {}, **kwargs):
+        device = kwargs.get("device", "cuda")
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
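
For context on these defaults: with `chunk_size = [0, 10, 5]` the feature cache allocated in the next hunk holds `chunk_size[0] + chunk_size[2] = 5` frames, which reads as the left-context plus look-ahead frames carried between chunks (an interpretation inferred from the allocation below, not from documentation):

    chunk_size = [0, 10, 5]                        # assumed [left context, chunk, look-ahead]
    cached_frames = chunk_size[0] + chunk_size[2]  # 5 frames kept across chunk boundaries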
@@ -575,10 +568,11 @@
    
        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
                         "cif_alphas": torch.zeros((batch_size, 1)).to(device=device), "chunk_size": chunk_size,
                         "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device),
                         "tail_chunk": False}
        cache["encoder"] = cache_encoder
    
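The `.to(device=device)` calls above fix a device mismatch: these cache tensors were previously created on CPU and later collided with CUDA-resident model tensors. As a side note, `torch.zeros(...).to(device=...)` allocates on CPU and then copies; passing `device=` to the constructor skips the round-trip. A sketch with placeholder sizes (batch 1, encoder output 512 are assumptions):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Same result as torch.zeros(...).to(device=device), without the CPU hop:
    cif_hidden = torch.zeros((1, 1, 512), device=device)
    cif_alphas = torch.zeros((1, 1), device=device)
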
@@ -586,8 +580,10 @@
                         "chunk_size": chunk_size}
        cache["decoder"] = cache_decoder
        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        cache["prev_samples"] = torch.empty(0).to(device=device)
        return cache
    def inference(self,
@@ -603,7 +599,10 @@
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-        if self.beam_search is None and (is_use_lm or is_use_ctc):
+        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
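
The final hunk removes the CTC/LM gating: `init_beam_search` used to run only when `decoding_ctc_weight` or `lm_weight` pushed past the `1e-5` threshold, whereas now the streaming beam search is built whenever `self.beam_search` is unset, since chunked SCAMA decoding always goes through `BeamSearchScamaStreaming`. A hypothetical end-to-end driver under these changes (entry-point name and arguments assumed from this diff, not a documented API):

    cache = model.init_cache({}, device="cuda")      # one dict carries all streaming state
    for i, chunk in enumerate(audio_chunks):
        results = model.inference(chunk, key=["utt1"], cache=cache,
                                  is_final=(i == len(audio_chunks) - 1))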