From cfe577f16fef9fb5b0a48f07d4f9e232799cc9d4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 08 五月 2024 00:03:52 +0800
Subject: [PATCH] decoding key
---
funasr/models/sense_voice/model.py | 41 +++++++++++++++++++++++++++++++++++++++++
1 files changed, 41 insertions(+), 0 deletions(-)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index bcaaca3..0230638 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -802,6 +802,16 @@
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer,
)
+
+ if (
+ isinstance(kwargs.get("data_type", None), (list, tuple))
+ and len(kwargs.get("data_type", [])) > 1
+ ):
+ audio_sample_list, text_token_int_list = audio_sample_list
+ text_token_int = text_token_int_list[0]
+ else:
+ text_token_int = None
+
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(
@@ -837,6 +847,37 @@
speech[None, :, :].permute(0, 2, 1), speech_lengths
)
+ if text_token_int is not None:
+ i = 0
+ results = []
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"1best_recog"]
+
+ # 1. Forward decoder
+ ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
+ None, :
+ ]
+ ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
+ kwargs["device"]
+ )[None, :]
+ decoder_out = self.model.decoder(
+ x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+ )
+
+ token_int = decoder_out.argmax(-1)[0, :].tolist()
+ text = tokenizer.decode(token_int)
+
+ result_i = {"key": key[i], "text": text}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ # ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ return results, meta_data
+
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=encoder_out[0],
--
Gitblit v1.9.1