From 84b4a01979ecc483096cccf5185dbe5e56946217 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期五, 26 五月 2023 11:43:27 +0800
Subject: [PATCH] add paraformer online infer and finetune
---
funasr/models/e2e_asr_paraformer.py | 96 ++++++++++++++++++++++++++++++++++++++++++++++-
funasr/models/decoder/sanm_decoder.py | 6 +-
funasr/bin/asr_inference_launch.py | 4 +
3 files changed, 99 insertions(+), 7 deletions(-)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index dbbb3ed..f5296f6 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1618,6 +1618,8 @@
return inference_uniasr(**kwargs)
elif mode == "paraformer":
return inference_paraformer(**kwargs)
+ elif mode == "paraformer_online":
+ return inference_paraformer(**kwargs)
elif mode == "paraformer_streaming":
return inference_paraformer_online(**kwargs)
elif mode.startswith("paraformer_vad"):
@@ -1900,4 +1902,4 @@
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index 508eb73..ed920bf 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -956,13 +956,13 @@
"""
tgt = ys_in_pad
tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
if chunk_mask is not None:
memory_mask = memory_mask * chunk_mask
if tgt_mask.size(1) != memory_mask.size(1):
memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
-
- memory = hs_pad
- memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
x, tgt_mask, memory, memory_mask, _ = self.decoders(
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 54db971..09af2cd 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -279,7 +279,7 @@
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
- ) -> Tuple[Tuple[Any, Optional[Any]], Any]:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
@@ -649,7 +649,35 @@
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
- super().__init__()
+ super().__init__(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ ctc_weight=ctc_weight,
+ interctc_weight=interctc_weight,
+ ignore_id=ignore_id,
+ blank_id=blank_id,
+ sos=sos,
+ eos=eos,
+ lsm_weight=lsm_weight,
+ length_normalized_loss=length_normalized_loss,
+ report_cer=report_cer,
+ report_wer=report_wer,
+ sym_space=sym_space,
+ sym_blank=sym_blank,
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+ predictor=predictor,
+ predictor_weight=predictor_weight,
+ predictor_bias=predictor_bias,
+ sampling_ratio=sampling_ratio,
+ )
# note that eos is the same as sos (equivalent ID)
self.blank_id = blank_id
self.sos = vocab_size - 1 if sos is None else sos
@@ -705,6 +733,7 @@
self.sampling_ratio = sampling_ratio
self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
self.step_cur = 0
+ self.scama_mask = None
if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@@ -859,7 +888,7 @@
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
+
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
@@ -1111,12 +1140,73 @@
return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
+ def calc_predictor(self, encoder_out, encoder_out_lens):
+
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ mask_chunk_predictor = None
+ if self.encoder.overlap_chunk_cls is not None:
+ mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(
+ 0))
+ mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+ batch_size=encoder_out.size(0))
+ encoder_out = encoder_out * mask_shfit_chunk
+ pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id,
+ mask_chunk_predictor=mask_chunk_predictor,
+ target_label_length=None,
+ )
+ predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas[:, :-1],
+ encoder_out_lens)
+
+ scama_mask = None
+ if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+ encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+ attention_chunk_center_bias = 0
+ attention_chunk_size = encoder_chunk_size
+ decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+ mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
+ get_mask_shift_att_chunk_decoder(None,
+ device=encoder_out.device,
+ batch_size=encoder_out.size(0)
+ )
+ scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+ predictor_alignments=predictor_alignments,
+ encoder_sequence_length=encoder_out_lens,
+ chunk_size=1,
+ encoder_chunk_size=encoder_chunk_size,
+ attention_chunk_center_bias=attention_chunk_center_bias,
+ attention_chunk_size=attention_chunk_size,
+ attention_chunk_type=self.decoder_attention_chunk_type,
+ step=None,
+ predictor_mask_chunk_hopping=mask_chunk_predictor,
+ decoder_att_look_back_factor=decoder_att_look_back_factor,
+ mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+ target_length=None,
+ is_training=self.training,
+ )
+ self.scama_mask = scama_mask
+
+ return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+
def calc_predictor_chunk(self, encoder_out, cache=None):
pre_acoustic_embeds, pre_token_length = \
self.predictor.forward_chunk(encoder_out, cache["encoder"])
return pre_acoustic_embeds, pre_token_length
+ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
decoder_outs = self.decoder.forward_chunk(
encoder_out, sematic_embeds, cache["decoder"]
--
Gitblit v1.9.1