From 3d9f094e9652d4b84894c6fd4eae39a4a753b0f0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 16 May 2023 23:48:00 +0800
Subject: [PATCH] train
---
funasr/models/e2e_asr_paraformer.py | 34 ++++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 9d4f106..00e08b1 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -55,9 +55,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -78,6 +76,8 @@
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
share_embedding: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -326,9 +326,8 @@
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
- encoder_out_mask,
- ignore_id=self.ignore_id)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
+ ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -659,6 +658,7 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
+
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -732,9 +735,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -757,6 +758,8 @@
embeds_id: int = 2,
embeds_loss_weight: float = 0.0,
embed_dims: int = 768,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -906,9 +909,9 @@
self.step_cur += 1
# for data-parallel
text = text[:, : text_lengths.max()]
- speech = speech[:, :speech_lengths.max(), :]
+ speech = speech[:, :speech_lengths.max()]
if embed is not None:
- embed = embed[:, :embed_lengths.max(), :]
+ embed = embed[:, :embed_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
@@ -1008,9 +1011,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -1030,6 +1031,8 @@
predictor_weight: float = 0.0,
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -1097,9 +1100,8 @@
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
- pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
- encoder_out_mask,
- ignore_id=self.ignore_id)
+ pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+ ignore_id=self.ignore_id)
# 0. sampler
decoder_out_1st = None
@@ -1277,9 +1279,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -1309,6 +1309,8 @@
bias_encoder_type: str = 'lstm',
label_bracket: bool = False,
use_decoder_embedding: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
--
Gitblit v1.9.1