liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/transducer/beam_search_transducer.py
@@ -92,21 +92,18 @@
        self.vocab_size = decoder.vocab_size
        assert beam_size <= self.vocab_size, (
            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
            % (
                beam_size,
                self.vocab_size,
            )
        assert (
            beam_size <= self.vocab_size
        ), "beam_size (%d) should be smaller than or equal to vocabulary size (%d)." % (
            beam_size,
            self.vocab_size,
        )
        self.beam_size = beam_size
        if search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
                max_sym_exp
            )
            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (max_sym_exp)
            self.max_sym_exp = max_sym_exp
            self.search_algorithm = self.time_sync_decoding
@@ -130,9 +127,7 @@
            self.search_algorithm = self.modified_adaptive_expansion_search
        else:
            raise NotImplementedError(
                "Specified search type (%s) is not supported." % search_type
            )
            raise NotImplementedError("Specified search type (%s) is not supported." % search_type)
        self.use_lm = lm is not None
@@ -244,17 +239,12 @@
        k_expansions = []
        for i, hyp in enumerate(hyps):
            hyp_i = [
                (int(k), hyp.score + float(v))
                for k, v in zip(topk_idx[i], topk_logp[i])
            ]
            hyp_i = [(int(k), hyp.score + float(v)) for k, v in zip(topk_idx[i], topk_logp[i])]
            k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
            k_expansions.append(
                sorted(
                    filter(
                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
                    ),
                    filter(lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i),
                    key=lambda x: x[1],
                    reverse=True,
                )
@@ -342,9 +332,7 @@
                if self.use_lm:
                    lm_scores, lm_state = self.lm.score(
                        torch.LongTensor(
                            [self.sos] + max_hyp.yseq[1:], device=self.decoder.device
                        ),
                        torch.LongTensor([self.sos] + max_hyp.yseq[1:], device=self.decoder.device),
                        max_hyp.lm_state,
                        None,
                    )
@@ -376,7 +364,7 @@
                    break
        return kept_hyps
    def align_length_sync_decoding(
        self,
        enc_out: torch.Tensor,