From 3d9f094e9652d4b84894c6fd4eae39a4a753b0f0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 五月 2023 23:48:00 +0800
Subject: [PATCH] train
---
funasr/models/e2e_asr_paraformer.py | 166 +++++++++++++++++++++++++++----------------------------
1 files changed, 81 insertions(+), 85 deletions(-)
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 699d85f..00e08b1 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -29,9 +29,8 @@
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
-
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -42,7 +41,7 @@
yield
-class Paraformer(AbsESPnetModel):
+class Paraformer(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -56,9 +55,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -79,6 +76,8 @@
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
share_embedding: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -153,7 +152,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -270,7 +268,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -368,9 +365,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -407,7 +402,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
@@ -664,7 +658,10 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
+<<<<<<< HEAD
+=======
+>>>>>>> 4cd79db451786548d8a100f25c3b03da0eb30f4b
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -712,9 +709,9 @@
def calc_predictor_chunk(self, encoder_out, cache=None):
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \
+ pre_acoustic_embeds, pre_token_length = \
self.predictor.forward_chunk(encoder_out, cache["encoder"])
- return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+ return pre_acoustic_embeds, pre_token_length
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
decoder_outs = self.decoder.forward_chunk(
@@ -738,9 +735,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -763,6 +758,8 @@
embeds_id: int = 2,
embeds_loss_weight: float = 0.0,
embed_dims: int = 768,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -894,7 +891,6 @@
embed_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -913,9 +909,9 @@
self.step_cur += 1
# for data-parallel
text = text[:, : text_lengths.max()]
- speech = speech[:, :speech_lengths.max(), :]
+ speech = speech[:, :speech_lengths.max()]
if embed is not None:
- embed = embed[:, :embed_lengths.max(), :]
+ embed = embed[:, :embed_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
@@ -1003,74 +999,73 @@
class BiCifParaformer(Paraformer):
-
"""
Paraformer model with an extra cif predictor
to conduct accurate timestamp prediction
"""
def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
- encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
- decoder: AbsDecoder,
- ctc: CTC,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- extract_feats_in_collect_stats: bool = True,
- predictor = None,
- predictor_weight: float = 0.0,
- predictor_bias: int = 0,
- sampling_ratio: float = 0.2,
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ encoder: AbsEncoder,
+ decoder: AbsDecoder,
+ ctc: CTC,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
+ predictor=None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- preencoder=preencoder,
- encoder=encoder,
- postencoder=postencoder,
- decoder=decoder,
- ctc=ctc,
- ctc_weight=ctc_weight,
- interctc_weight=interctc_weight,
- ignore_id=ignore_id,
- blank_id=blank_id,
- sos=sos,
- eos=eos,
- lsm_weight=lsm_weight,
- length_normalized_loss=length_normalized_loss,
- report_cer=report_cer,
- report_wer=report_wer,
- sym_space=sym_space,
- sym_blank=sym_blank,
- extract_feats_in_collect_stats=extract_feats_in_collect_stats,
- predictor=predictor,
- predictor_weight=predictor_weight,
- predictor_bias=predictor_bias,
- sampling_ratio=sampling_ratio,
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ ctc_weight=ctc_weight,
+ interctc_weight=interctc_weight,
+ ignore_id=ignore_id,
+ blank_id=blank_id,
+ sos=sos,
+ eos=eos,
+ lsm_weight=lsm_weight,
+ length_normalized_loss=length_normalized_loss,
+ report_cer=report_cer,
+ report_wer=report_wer,
+ sym_space=sym_space,
+ sym_blank=sym_blank,
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+ predictor=predictor,
+ predictor_weight=predictor_weight,
+ predictor_bias=predictor_bias,
+ sampling_ratio=sampling_ratio,
)
assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
@@ -1145,21 +1140,23 @@
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att, loss_pre
-
+
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask,
- ignore_id=self.ignore_id)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
+
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
+ encoder_out_mask,
+ token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def forward(
@@ -1170,7 +1167,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -1253,7 +1249,8 @@
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+ loss = self.ctc_weight * loss_ctc + (
+ 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1282,9 +1279,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -1314,6 +1309,8 @@
bias_encoder_type: str = 'lstm',
label_bracket: bool = False,
use_decoder_embedding: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -1377,7 +1374,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -1761,4 +1757,4 @@
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
- return var_dict_torch_update
+ return var_dict_torch_update
\ No newline at end of file
--
Gitblit v1.9.1