From 9ba0dbd98bf69c830dfcfde8f109a400cb65e4e5 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 29 三月 2024 17:24:59 +0800
Subject: [PATCH] fix func Forward
---
funasr/models/bicif_paraformer/model.py | 130 ++++++++++++++++++++++++++++++++++++-------
1 files changed, 108 insertions(+), 22 deletions(-)
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index 01f19c6..9849c8c 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -23,7 +23,7 @@
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+from funasr.train_utils.device_funcs import to_device
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -235,23 +235,23 @@
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
- if isinstance(data_in, torch.Tensor): # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
- frontend=frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+ # if isinstance(data_in, torch.Tensor): # fbank
+ # speech, speech_lengths = data_in, data_lengths
+ # if len(speech.shape) < 3:
+ # speech = speech[None, :, :]
+ # if speech_lengths is None:
+ # speech_lengths = speech.shape[1]
+ # else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
@@ -300,9 +300,11 @@
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for nbest_idx, hyp in enumerate(nbest_hyps):
ibest_writer = None
- if ibest_writer is None and kwargs.get("output_dir") is not None:
- writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{nbest_idx+1}best_recog"]
+
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
@@ -339,4 +341,88 @@
result_i = {"key": key[i], "token_int": token_int}
results.append(result_i)
- return results, meta_data
\ No newline at end of file
+ return results, meta_data
+
+ def export(
+ self,
+ max_seq_len=512,
+ **kwargs,
+ ):
+ self.device = kwargs.get("device")
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
+ encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+ self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+
+ predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
+ self.predictor = predictor_class(self.predictor, onnx=is_onnx)
+
+ decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
+ self.decoder = decoder_class(self.decoder, onnx=is_onnx)
+
+ from funasr.utils.torch_function import sequence_mask
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+
+ self.forward = self.export_forward
+
+ return self
+
+ def export_forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ):
+ # a. To device
+ batch = {"speech": speech, "speech_lengths": speech_lengths}
+ batch = to_device(batch, device=self.device)
+
+ enc, enc_len = self.encoder(**batch)
+ mask = self.make_pad_mask(enc_len)[:, None, :]
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+ pre_token_length = pre_token_length.round().type(torch.int32)
+
+ decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+
+ # get predicted timestamps
+ us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
+
+ return decoder_out, pre_token_length, us_alphas, us_cif_peak
+
+ def export_dummy_inputs(self):
+ speech = torch.randn(2, 30, 560)
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+ return (speech, speech_lengths)
+
+ def export_input_names(self):
+ return ['speech', 'speech_lengths']
+
+ def export_output_names(self):
+ return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
+
+ def export_dynamic_axes(self):
+ return {
+ 'speech': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'speech_lengths': {
+ 0: 'batch_size',
+ },
+ 'logits': {
+ 0: 'batch_size',
+ 1: 'logits_length'
+ },
+ 'us_alphas': {
+ 0: 'batch_size',
+ 1: 'alphas_length'
+ },
+ 'us_cif_peak': {
+ 0: 'batch_size',
+ 1: 'alphas_length'
+ },
+ }
+
+ def export_name(self, ):
+ return "model.onnx"
--
Gitblit v1.9.1