From 47088b8d1ebe42b6c376236c19184ef4f440cc0d Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 11 一月 2024 00:09:36 +0800
Subject: [PATCH] funasr1.0 paraformer_streaming
---
runtime/python/onnxruntime/setup.py | 2
funasr/models/paraformer_streaming/model.py | 153 ++++++++++++++++++++++++++++++++++++++++-----------
2 files changed, 121 insertions(+), 34 deletions(-)
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 498d363..304c0f7 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -375,7 +375,7 @@
return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
- def calc_predictor_chunk(self, encoder_out, cache=None):
+ def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
pre_acoustic_embeds, pre_token_length = \
self.predictor.forward_chunk(encoder_out, cache["encoder"])
@@ -389,48 +389,72 @@
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
- def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+ def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
decoder_outs = self.decoder.forward_chunk(
encoder_out, sematic_embeds, cache["decoder"]
)
decoder_out = decoder_outs
decoder_out = torch.log_softmax(decoder_out, dim=-1)
- return decoder_out
+ return decoder_out, ys_pad_lens
+
+ def init_cache(self, cache: dict = {}, **kwargs):
+ chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+ encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
+ decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
+ batch_size = 1
- def generate(self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- tokenizer=None,
- **kwargs,
- ):
+ enc_output_size = kwargs["encoder_conf"]["output_size"]
+ feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
+ cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+ "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+ "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+ "tail_chunk": False}
+ cache["encoder"] = cache_encoder
- is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
- print(is_use_ctc)
- is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
+ "chunk_size": chunk_size}
+ cache["decoder"] = cache_decoder
+ cache["frontend"] = {}
+ cache["prev_samples"] = []
- if self.beam_search is None and (is_use_lm or is_use_ctc):
- logging.info("enable beam_search")
- self.init_beam_search(speech, speech_lengths, **kwargs)
- self.nbest = kwargs.get("nbest", 1)
+ return cache
+
+ def generate_chunk(self,
+ speech,
+ speech_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+ cache = kwargs.get("cache", {})
+ speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
- # Forward Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# predictor
- predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+ predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
predictor_outs[2], predictor_outs[3]
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
- pre_token_length)
+ decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
+ encoder_out_lens,
+ pre_acoustic_embeds,
+ pre_token_length,
+ cache=cache
+ )
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
+
results = []
b, n, d = decoder_out.size()
+ if isinstance(key[0], (list, tuple)):
+ key = key[0]
for i in range(b):
x = encoder_out[i, :encoder_out_lens[i], :]
am_scores = decoder_out[i, :pre_token_length[i], :]
@@ -451,9 +475,11 @@
[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if ibest_writer is None and kwargs.get("output_dir") is not None:
+ writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
@@ -462,15 +488,76 @@
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- timestamp = []
-
- results.append((text, token, timestamp))
+ if tokenizer is not None:
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+
+ result_i = {"key": key[i], "text": text_postprocessed}
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ # ibest_writer["text"][key[i]] = text
+ ibest_writer["text"][key[i]] = text_postprocessed
+ else:
+ result_i = {"key": key[i], "token_int": token_int}
+ results.append(result_i)
return results
+
+ def generate(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ # init beamsearch
+ is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+ is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+ if self.beam_search is None and (is_use_lm or is_use_ctc):
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ cache = kwargs.get("cache", {})
+ if len(cache) == 0:
+ self.init_cache(cache, **kwargs)
+
+ meta_data = {}
+ chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+ chunk_stride_samples = chunk_size[1] * 960 # 600ms
+
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ assert len(audio_sample_list) == 1, "batch_size must be set 1"
+
+ audio_sample = cache["prev_samples"] + audio_sample_list[0]
+
+ n = len(audio_sample) // chunk_stride_samples
+ m = len(audio_sample) % chunk_stride_samples
+ for i in range(n):
+ audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
+
+ # extract fbank feats
+ speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend, cache=cache["frontend"])
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
+
+ cache["prev_samples"] = audio_sample[:-m]
+
diff --git a/runtime/python/onnxruntime/setup.py b/runtime/python/onnxruntime/setup.py
index b8dc3e1..3c128ee 100644
--- a/runtime/python/onnxruntime/setup.py
+++ b/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@
MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.2.4'
+VERSION_NUM = '0.2.5'
setuptools.setup(
name=MODULE_NAME,
--
Gitblit v1.9.1