From df5f263e5fe3d7961b1aeb3589012400a9905a8f Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Mon, 24 Apr 2023 16:17:41 +0800
Subject: [PATCH] update
---
funasr/models/e2e_uni_asr.py | 28 ++++++++++++----------------
 1 file changed, 12 insertions(+), 16 deletions(-)
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index 03fbca9..0c53389 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -17,15 +17,12 @@
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
from funasr.models.predictor.cif import mae_loss
@@ -38,7 +35,7 @@
yield
-class UniASR(AbsESPnetModel):
+class UniASR(FunASRModel):
"""
Author: Speech Lab, Alibaba Group, China
"""
@@ -47,11 +44,11 @@
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
+ frontend: Optional[torch.nn.Module],
+ specaug: Optional[torch.nn.Module],
+ normalize: Optional[torch.nn.Module],
preencoder: Optional[AbsPreEncoder],
- encoder: AbsEncoder,
+ encoder: torch.nn.Module,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
@@ -198,16 +195,15 @@
# for data-parallel
text = text[:, : text_lengths.max()]
- speech = speech[:, :speech_lengths.max(), :]
+ speech = speech[:, :speech_lengths.max()]
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
- speech_raw = speech.clone().to(speech.device)
# 1. Encoder
if self.enable_maas_finetune:
with torch.no_grad():
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+ speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
else:
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+ speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
intermediate_outs = None
if isinstance(encoder_out, tuple):
@@ -486,7 +482,7 @@
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
-
+ speech_raw = feats.clone().to(feats.device)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
@@ -523,7 +519,7 @@
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
- return encoder_out, encoder_out_lens
+ return speech_raw, encoder_out, encoder_out_lens
def encode2(
self,
--
Gitblit v1.9.1