From 54931dd4e1a099d7d6f144c4e12e5453deb3aa26 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Wed, 28 Jun 2023 10:41:57 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main
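
This merge brings in three changes to funasr/models/e2e_asr_paraformer.py:
the 1st-pass decoder loss is now weighted by (1 - ctc_weight) so it is
balanced against loss_att; the subclass constructor forwards all of its
arguments to the parent class; and full-utterance calc_predictor() /
cal_decoder_with_predictor() methods are added, caching a SCAMA
cross-attention mask between the two calls.

A minimal sketch of the intended two-pass inference flow (model, speech,
and speech_lengths are illustrative names; only the methods shown here
are defined by this patch):

    encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)
    pre_acoustic_embeds, pre_token_length, _, _ = model.calc_predictor(
        encoder_out, encoder_out_lens)
    decoder_out, ys_pad_lens = model.cal_decoder_with_predictor(
        encoder_out, encoder_out_lens, pre_acoustic_embeds,
        pre_token_length.round().long())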
---
funasr/models/e2e_asr_paraformer.py | 98 +++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 94 insertions(+), 4 deletions(-)
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 54db971..686038e 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -242,7 +242,7 @@
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
if self.use_1st_decoder_loss and pre_loss_att is not None:
- loss = loss + pre_loss_att
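+            # Scale the 1st-pass decoder loss by (1 - ctc_weight) to match the weighting of loss_att above.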
+ loss = loss + (1 - self.ctc_weight) * pre_loss_att
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -279,7 +279,7 @@
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
- ) -> Tuple[Tuple[Any, Optional[Any]], Any]:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
@@ -649,7 +649,35 @@
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
- super().__init__()
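+        # Pass every constructor argument through to the parent class (previously super().__init__() was called bare).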
+ super().__init__(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ ctc_weight=ctc_weight,
+ interctc_weight=interctc_weight,
+ ignore_id=ignore_id,
+ blank_id=blank_id,
+ sos=sos,
+ eos=eos,
+ lsm_weight=lsm_weight,
+ length_normalized_loss=length_normalized_loss,
+ report_cer=report_cer,
+ report_wer=report_wer,
+ sym_space=sym_space,
+ sym_blank=sym_blank,
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+ predictor=predictor,
+ predictor_weight=predictor_weight,
+ predictor_bias=predictor_bias,
+ sampling_ratio=sampling_ratio,
+ )
# note that eos is the same as sos (equivalent ID)
self.blank_id = blank_id
self.sos = vocab_size - 1 if sos is None else sos
@@ -705,6 +733,7 @@
self.sampling_ratio = sampling_ratio
self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
self.step_cur = 0
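+        # Cache for the SCAMA cross-attention mask: filled by calc_predictor() and read by cal_decoder_with_predictor().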
+ self.scama_mask = None
if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@@ -859,7 +888,7 @@
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
+
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
@@ -1111,12 +1140,73 @@
return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
+ def calc_predictor(self, encoder_out, encoder_out_lens):
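+        """Run the predictor over full-utterance encoder output.
+
+        Returns token-level acoustic embeddings, predicted token lengths, CIF
+        alphas, and peak indices. As a side effect, caches the SCAMA
+        cross-attention mask in self.scama_mask for cal_decoder_with_predictor().
+        """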
+
+ encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+ encoder_out.device)
+ mask_chunk_predictor = None
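+        # For chunk-streaming (SCAMA) encoders, restrict the predictor input to chunk-aligned frames.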
+ if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0))
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0))
+            encoder_out = encoder_out * mask_shfit_chunk
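+        # Predictor (CIF): integrates encoder frames into token-level acoustic embeddings and firing positions.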
+        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
+            encoder_out,
+            None,
+            encoder_out_mask,
+            ignore_id=self.ignore_id,
+            mask_chunk_predictor=mask_chunk_predictor,
+            target_label_length=None,
+        )
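+        # Turn the frame-wise alphas into hard alignments; a positive tail_threshold reserves one extra frame for the tail token.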
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
+            pre_alphas,
+            encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens,
+        )
+
+ scama_mask = None
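+        # Build the SCAMA mask that limits decoder cross-attention to the chunks each predicted token may attend to.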
+ if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+ encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+ attention_chunk_center_bias = 0
+ attention_chunk_size = encoder_chunk_size
+ decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
+                None,
+                device=encoder_out.device,
+                batch_size=encoder_out.size(0),
+            )
+ scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+ predictor_alignments=predictor_alignments,
+ encoder_sequence_length=encoder_out_lens,
+ chunk_size=1,
+ encoder_chunk_size=encoder_chunk_size,
+ attention_chunk_center_bias=attention_chunk_center_bias,
+ attention_chunk_size=attention_chunk_size,
+ attention_chunk_type=self.decoder_attention_chunk_type,
+ step=None,
+ predictor_mask_chunk_hopping=mask_chunk_predictor,
+ decoder_att_look_back_factor=decoder_att_look_back_factor,
+ mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+ target_length=None,
+ is_training=self.training,
+ )
+ self.scama_mask = scama_mask
+
+ return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+
def calc_predictor_chunk(self, encoder_out, cache=None):
pre_acoustic_embeds, pre_token_length = \
self.predictor.forward_chunk(encoder_out, cache["encoder"])
return pre_acoustic_embeds, pre_token_length
+
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
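+        # Second pass: decode the predictor's acoustic embeddings against the encoder
+        # output under the SCAMA mask cached by calc_predictor(); returns per-token
+        # log-probabilities.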
+ decoder_outs = self.decoder(
+ encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+ )
+ decoder_out = decoder_outs[0]
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+ return decoder_out, ys_pad_lens
+
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
decoder_outs = self.decoder.forward_chunk(
encoder_out, sematic_embeds, cache["decoder"]
--
Gitblit v1.9.1