From 15bb198d4af581356917ec4c96322fefbd1e01d3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 06 May 2024 21:55:03 +0800
Subject: [PATCH] decoding
---
funasr/models/sense_voice/model.py | 102 ++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 92 insertions(+), 10 deletions(-)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 82ccc55..d5e4130 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -396,6 +396,42 @@
return loss_att, acc_att, None, None
+ def init_beam_search(
+ self,
+ **kwargs,
+ ):
+ from .search import BeamSearch
+
+ from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+ scorers.update(
+ decoder=self.model.decoder,
+ length_bonus=LengthBonus(self.vocab_size),
+ )
+
+ weights = dict(
+ decoder=1.0,
+ ctc=0.0,
+ lm=0.0,
+ ngram=0.0,
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearch(
+ beam_size=kwargs.get("beam_size", 5),
+ weights=weights,
+ scorers=scorers,
+ sos=None,
+ eos=None,
+ vocab_size=self.vocab_size,
+ token_list=None,
+ pre_beam_score_key="full",
+ )
+
+ self.beam_search = beam_search
+
def inference(
self,
data_in,
@@ -407,6 +443,12 @@
):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
+
+ # init beam search
+ if not hasattr(self, "beam_search") or self.beam_search is None:
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
if frontend is None and not hasattr(self, "frontend"):
frontend_class = tables.frontend_classes.get("WhisperFrontend")
@@ -456,25 +498,65 @@
task = [task]
task = "".join([f"<|{x}|>" for x in task])
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
- DecodingOptions["initial_prompt"] = initial_prompt
language = DecodingOptions.get("language", None)
language = None if language == "auto" else language
- DecodingOptions["language"] = language
- DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
+ sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+ sos_int = tokenizer.encode(sos, allowed_special="all")
+ eos = kwargs.get("model_conf").get("eos")
+ eos_int = tokenizer.encode(eos, allowed_special="all")
+ self.beam_search.sos = sos_int
+ self.beam_search.eos = eos_int[0]
- if "without_timestamps" not in DecodingOptions:
- DecodingOptions["without_timestamps"] = True
+ encoder_out, encoder_out_lens = self.encode(
+ speech[None, :, :].permute(0, 2, 1), speech_lengths
+ )
- options = whisper.DecodingOptions(**DecodingOptions)
+ # c. Pass the encoder output to the beam search
+ nbest_hyps = self.beam_search(
+ x=encoder_out[0],
+ maxlenratio=kwargs.get("maxlenratio", 0.0),
+ minlenratio=kwargs.get("minlenratio", 0.0),
+ )
- result = whisper.decode(self.model, speech, options)
- text = f"{result.text}"
+ nbest_hyps = nbest_hyps[: self.nbest]
+
results = []
- result_i = {"key": key[0], "text": text}
+ b, n, d = encoder_out.size()
+ for i in range(b):
- results.append(result_i)
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # # remove blank symbol id, which is assumed to be 0
+ # token_int = list(
+ # filter(
+ # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+ # )
+ # )
+
+ # Change integer-ids to tokens
+ # token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.decode(token_int)
+
+ result_i = {"key": key[i], "text": text}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ # ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
return results, meta_data
--
Gitblit v1.9.1