From 607073619cedf2c114e1589aa6d5953d171f33bf Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 27 四月 2023 19:27:49 +0800
Subject: [PATCH] update

---
 funasr/models/e2e_diar_sond.py      |   28 +
 funasr/models/e2e_tp.py             |   26 
 funasr/models/e2e_asr.py            |   21 
 funasr/models/e2e_asr_paraformer.py |  511 +++++++++++++++++++++++++--------
 funasr/models/e2e_asr_mfcca.py      |  148 +++++----
 funasr/models/e2e_vad.py            |   45 ++
 funasr/models/e2e_sv.py             |   30 +
 funasr/models/e2e_uni_asr.py        |   22 
 funasr/models/data2vec.py           |   14 
 funasr/models/e2e_diar_eend_ola.py  |    3 
 10 files changed, 579 insertions(+), 269 deletions(-)

diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
index 380c137..e5bd640 100644
--- a/funasr/models/data2vec.py
+++ b/funasr/models/data2vec.py
@@ -12,7 +12,11 @@
 import torch
 from typeguard import check_argument_types
 
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.models.base_model import FunASRModel
 
@@ -30,11 +34,11 @@
 
     def __init__(
             self,
-            frontend: Optional[torch.nn.Module],
-            specaug: Optional[torch.nn.Module],
-            normalize: Optional[torch.nn.Module],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
             preencoder: Optional[AbsPreEncoder],
-            encoder: torch.nn.Module,
+            encoder: AbsEncoder,
     ):
         assert check_argument_types()
 
@@ -53,7 +57,6 @@
             speech_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -102,7 +105,6 @@
             speech_lengths: torch.Tensor,
     ):
         """Frontend + Encoder.
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py
index 779d703..8410ede 100644
--- a/funasr/models/e2e_asr.py
+++ b/funasr/models/e2e_asr.py
@@ -13,18 +13,22 @@
 import torch
 from typeguard import check_argument_types
 
+from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.base_model import FunASRModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.modules.add_sos_eos import add_sos_eos
 from funasr.modules.e2e_asr_common import ErrorCalculator
 from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -43,9 +47,11 @@
             vocab_size: int,
             token_list: Union[Tuple[str, ...], List[str]],
             frontend: Optional[AbsFrontend],
-            specaug: Optional[torch.nn.Module],
-            normalize: Optional[torch.nn.Module],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            preencoder: Optional[AbsPreEncoder],
             encoder: AbsEncoder,
+            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -127,7 +133,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -243,7 +248,6 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -325,9 +329,7 @@
             ys_pad_lens: torch.Tensor,
     ) -> torch.Tensor:
         """Compute negative log likelihood(nll) from transformer-decoder
-
         Normally, this function is called in batchify_nll.
-
         Args:
             encoder_out: (Batch, Length, Dim)
             encoder_out_lens: (Batch,)
@@ -364,7 +366,6 @@
             batch_size: int = 100,
     ):
         """Compute negative log likelihood(nll) from transformer-decoder
-
         To avoid OOM, this fuction seperate the input into batches.
         Then call nll for each batch and combine and return results.
         Args:
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index efdd90d..44679ef 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -17,10 +17,13 @@
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.base_model import FunASRModel
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
-
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -32,30 +35,36 @@
 import pdb
 import random
 import math
+
+
 class MFCCA(FunASRModel):
-    """CTC-attention hybrid Encoder-Decoder model"""
+    """
+    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
+    MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
+    https://arxiv.org/abs/2210.05265
+    """
 
     def __init__(
-        self,
-        vocab_size: int,
-        token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[torch.nn.Module],
-        specaug: Optional[torch.nn.Module],
-        normalize: Optional[torch.nn.Module],
-        preencoder: Optional[AbsPreEncoder],
-        encoder: torch.nn.Module,
-        decoder: AbsDecoder,
-        ctc: CTC,
-        rnnt_decoder: None,
-        ctc_weight: float = 0.5,
-        ignore_id: int = -1,
-        lsm_weight: float = 0.0,
-        mask_ratio: float = 0.0,
-        length_normalized_loss: bool = False,
-        report_cer: bool = True,
-        report_wer: bool = True,
-        sym_space: str = "<space>",
-        sym_blank: str = "<blank>",
+            self,
+            vocab_size: int,
+            token_list: Union[Tuple[str, ...], List[str]],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            preencoder: Optional[AbsPreEncoder],
+            encoder: AbsEncoder,
+            decoder: AbsDecoder,
+            ctc: CTC,
+            rnnt_decoder: None,
+            ctc_weight: float = 0.5,
+            ignore_id: int = -1,
+            lsm_weight: float = 0.0,
+            mask_ratio: float = 0.0,
+            length_normalized_loss: bool = False,
+            report_cer: bool = True,
+            report_wer: bool = True,
+            sym_space: str = "<space>",
+            sym_blank: str = "<blank>",
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -69,10 +78,9 @@
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
         self.token_list = token_list.copy()
-        
+
         self.mask_ratio = mask_ratio
 
-        
         self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
@@ -106,14 +114,13 @@
             self.error_calculator = None
 
     def forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -123,22 +130,22 @@
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
-            speech.shape[0]
-            == speech_lengths.shape[0]
-            == text.shape[0]
-            == text_lengths.shape[0]
+                speech.shape[0]
+                == speech_lengths.shape[0]
+                == text.shape[0]
+                == text_lengths.shape[0]
         ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-        #pdb.set_trace()
-        if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
+        # pdb.set_trace()
+        if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
             rate_num = random.random()
-            #rate_num = 0.1
-            if(rate_num<=self.mask_ratio):
-                retain_channel = math.ceil(random.random() *8)
-                if(retain_channel>1):
-                    speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
+            # rate_num = 0.1
+            if (rate_num <= self.mask_ratio):
+                retain_channel = math.ceil(random.random() * 8)
+                if (retain_channel > 1):
+                    speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
                 else:
-                    speech = speech[:,:,torch.randperm(8)[0]]
-        #pdb.set_trace()
+                    speech = speech[:, :, torch.randperm(8)[0]]
+        # pdb.set_trace()
         batch_size = speech.shape[0]
         # for data-parallel
         text = text[:, : text_lengths.max()]
@@ -188,20 +195,19 @@
         return loss, stats, weight
 
     def collect_feats(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
     ) -> Dict[str, torch.Tensor]:
         feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -220,14 +226,14 @@
         # Pre-encoder, e.g. used for raw input data
         if self.preencoder is not None:
             feats, feats_lengths = self.preencoder(feats, feats_lengths)
-        #pdb.set_trace()
+        # pdb.set_trace()
         encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
             speech.size(0),
         )
-        if(encoder_out.dim()==4):
+        if (encoder_out.dim() == 4):
             assert encoder_out.size(2) <= encoder_out_lens.max(), (
                 encoder_out.size(),
                 encoder_out_lens.max(),
@@ -241,7 +247,7 @@
         return encoder_out, encoder_out_lens
 
     def _extract_feats(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert speech_lengths.dim() == 1, speech_lengths.shape
         # for data-parallel
@@ -259,11 +265,11 @@
         return feats, feats_lengths, channel_size
 
     def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
@@ -291,14 +297,14 @@
         return loss_att, acc_att, cer_att, wer_att
 
     def _calc_ctc_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
         # Calc CTC loss
-        if(encoder_out.dim()==4):
+        if (encoder_out.dim() == 4):
             encoder_out = encoder_out.mean(1)
         loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
 
@@ -310,10 +316,10 @@
         return loss_ctc, cer_ctc
 
     def _calc_rnnt_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
-        raise NotImplementedError
+        raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index f414e4f..9d4f106 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -12,22 +12,25 @@
 import numpy as np
 from typeguard import check_argument_types
 
+from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.predictor.cif import mae_loss
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.base_model import FunASRModel
+from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.modules.add_sos_eos import add_sos_eos
 from funasr.modules.nets_utils import make_pad_mask, pad_list
 from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
 from funasr.models.predictor.cif import CifPredictorV3
-
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -40,7 +43,7 @@
 
 class Paraformer(FunASRModel):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2206.08317
     """
@@ -49,10 +52,12 @@
             self,
             vocab_size: int,
             token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[torch.nn.Module],
-            specaug: Optional[torch.nn.Module],
-            normalize: Optional[torch.nn.Module],
-            encoder: torch.nn.Module,
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            preencoder: Optional[AbsPreEncoder],
+            encoder: AbsEncoder,
+            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -92,7 +97,16 @@
         self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
+        self.preencoder = preencoder
+        self.postencoder = postencoder
         self.encoder = encoder
+
+        if not hasattr(self.encoder, "interctc_use_conditioning"):
+            self.encoder.interctc_use_conditioning = False
+        if self.encoder.interctc_use_conditioning:
+            self.encoder.conditioning_layer = torch.nn.Linear(
+                vocab_size, self.encoder.output_size()
+            )
 
         self.error_calculator = None
 
@@ -138,7 +152,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -161,7 +174,9 @@
 
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        intermediate_outs = None
         if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
             encoder_out = encoder_out[0]
 
         loss_att, acc_att, cer_att, wer_att = None, None, None, None
@@ -178,6 +193,30 @@
             # Collect CTC branch stats
             stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
             stats["cer_ctc"] = cer_ctc
+
+        # Intermediate CTC (optional)
+        loss_interctc = 0.0
+        if self.interctc_weight != 0.0 and intermediate_outs is not None:
+            for layer_idx, intermediate_out in intermediate_outs:
+                # we assume intermediate_out has the same length & padding
+                # as those of encoder_out
+                loss_ic, cer_ic = self._calc_ctc_loss(
+                    intermediate_out, encoder_out_lens, text, text_lengths
+                )
+                loss_interctc = loss_interctc + loss_ic
+
+                # Collect Intermedaite CTC stats
+                stats["loss_interctc_layer{}".format(layer_idx)] = (
+                    loss_ic.detach() if loss_ic is not None else None
+                )
+                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+            loss_interctc = loss_interctc / len(intermediate_outs)
+
+            # calculate whole encoder loss
+            loss_ctc = (
+                               1 - self.interctc_weight
+                       ) * loss_ctc + self.interctc_weight * loss_interctc
 
         # 2b. Attention decoder branch
         if self.ctc_weight != 1.0:
@@ -229,7 +268,6 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -246,8 +284,29 @@
             if self.normalize is not None:
                 feats, feats_lengths = self.normalize(feats, feats_lengths)
 
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
         # 4. Forward encoder
-        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder(
+                feats, feats_lengths, ctc=self.ctc
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        # Post-encoder, e.g. NLU
+        if self.postencoder is not None:
+            encoder_out, encoder_out_lens = self.postencoder(
+                encoder_out, encoder_out_lens
+            )
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
@@ -258,45 +317,18 @@
             encoder_out_lens.max(),
         )
 
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+
         return encoder_out, encoder_out_lens
-
-    def encode_chunk(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Frontend + Encoder. Note that this method is used by asr_inference.py
-
-        Args:
-                speech: (Batch, Length, ...)
-                speech_lengths: (Batch, )
-        """
-        with autocast(False):
-            # 1. Extract feats
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
-            # 2. Data augmentation
-            if self.specaug is not None and self.training:
-                feats, feats_lengths = self.specaug(feats, feats_lengths)
-
-            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-            if self.normalize is not None:
-                feats, feats_lengths = self.normalize(feats, feats_lengths)
-
-        # 4. Forward encoder
-        encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
-
-        return encoder_out, torch.tensor([encoder_out.size(1)])
 
     def calc_predictor(self, encoder_out, encoder_out_lens):
 
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
-        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
-    def calc_predictor_chunk(self, encoder_out, cache=None):
-
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
+                                                                                       encoder_out_mask,
+                                                                                       ignore_id=self.ignore_id)
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
 
     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -307,14 +339,6 @@
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
-
-    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
-        decoder_outs = self.decoder.forward_chunk(
-            encoder_out, sematic_embeds, cache["decoder"]
-        )
-        decoder_out = decoder_outs
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        return decoder_out
 
     def _extract_feats(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
@@ -342,9 +366,7 @@
             ys_pad_lens: torch.Tensor,
     ) -> torch.Tensor:
         """Compute negative log likelihood(nll) from transformer-decoder
-
         Normally, this function is called in batchify_nll.
-
         Args:
                 encoder_out: (Batch, Length, Dim)
                 encoder_out_lens: (Batch,)
@@ -381,7 +403,6 @@
             batch_size: int = 100,
     ):
         """Compute negative log likelihood(nll) from transformer-decoder
-
         To avoid OOM, this fuction seperate the input into batches.
         Then call nll for each batch and combine and return results.
         Args:
@@ -521,9 +542,186 @@
         return loss_ctc, cer_ctc
 
 
-class ParaformerBert(Paraformer):
+class ParaformerOnline(Paraformer):
     """
     Author: Speech Lab, Alibaba Group, China
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+            self, *args, **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (
+                speech.shape[0]
+                == speech_lengths.shape[0]
+                == text.shape[0]
+                == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        batch_size = speech.shape[0]
+        self.step_cur += 1
+        # for data-parallel
+        text = text[:, : text_lengths.max()]
+        speech = speech[:, :speech_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        loss_att, acc_att, cer_att, wer_att = None, None, None, None
+        loss_ctc, cer_ctc = None, None
+        loss_pre = None
+        stats = dict()
+
+        # 1. CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+            # Collect CTC branch stats
+            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+            stats["cer_ctc"] = cer_ctc
+
+        # Intermediate CTC (optional)
+        loss_interctc = 0.0
+        if self.interctc_weight != 0.0 and intermediate_outs is not None:
+            for layer_idx, intermediate_out in intermediate_outs:
+                # we assume intermediate_out has the same length & padding
+                # as those of encoder_out
+                loss_ic, cer_ic = self._calc_ctc_loss(
+                    intermediate_out, encoder_out_lens, text, text_lengths
+                )
+                loss_interctc = loss_interctc + loss_ic
+
+                # Collect Intermedaite CTC stats
+                stats["loss_interctc_layer{}".format(layer_idx)] = (
+                    loss_ic.detach() if loss_ic is not None else None
+                )
+                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+            loss_interctc = loss_interctc / len(intermediate_outs)
+
+            # calculate whole encoder loss
+            loss_ctc = (
+                               1 - self.interctc_weight
+                       ) * loss_ctc + self.interctc_weight * loss_interctc
+
+        # 2b. Attention decoder branch
+        if self.ctc_weight != 1.0:
+            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att + loss_pre * self.predictor_weight
+        elif self.ctc_weight == 1.0:
+            loss = loss_ctc
+        else:
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+        # Collect Attn branch stats
+        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["acc"] = acc_att
+        stats["cer"] = cer_att
+        stats["wer"] = wer_att
+        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode_chunk(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        # 4. Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
+                feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        # Post-encoder, e.g. NLU
+        if self.postencoder is not None:
+            encoder_out, encoder_out_lens = self.postencoder(
+                encoder_out, encoder_out_lens
+            )
+
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+
+        return encoder_out, torch.tensor([encoder_out.size(1)])
+
+    def calc_predictor_chunk(self, encoder_out, cache=None):
+
+        pre_acoustic_embeds, pre_token_length = \
+            self.predictor.forward_chunk(encoder_out, cache["encoder"])
+        return pre_acoustic_embeds, pre_token_length
+
+    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+        decoder_outs = self.decoder.forward_chunk(
+            encoder_out, sematic_embeds, cache["decoder"]
+        )
+        decoder_out = decoder_outs
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out
+
+
+class ParaformerBert(Paraformer):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
     """
 
@@ -531,11 +729,11 @@
             self,
             vocab_size: int,
             token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[torch.nn.Module],
-            specaug: Optional[torch.nn.Module],
-            normalize: Optional[torch.nn.Module],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
             preencoder: Optional[AbsPreEncoder],
-            encoder: torch.nn.Module,
+            encoder: AbsEncoder,
             postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
@@ -690,7 +888,6 @@
             embed_lengths: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -799,74 +996,73 @@
 
 
 class BiCifParaformer(Paraformer):
-
     """
     Paraformer model with an extra cif predictor
     to conduct accurate timestamp prediction
     """
 
     def __init__(
-        self,
-        vocab_size: int,
-        token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[torch.nn.Module],
-        specaug: Optional[torch.nn.Module],
-        normalize: Optional[torch.nn.Module],
-        preencoder: Optional[AbsPreEncoder],
-        encoder: torch.nn.Module,
-        postencoder: Optional[AbsPostEncoder],
-        decoder: AbsDecoder,
-        ctc: CTC,
-        ctc_weight: float = 0.5,
-        interctc_weight: float = 0.0,
-        ignore_id: int = -1,
-        blank_id: int = 0,
-        sos: int = 1,
-        eos: int = 2,
-        lsm_weight: float = 0.0,
-        length_normalized_loss: bool = False,
-        report_cer: bool = True,
-        report_wer: bool = True,
-        sym_space: str = "<space>",
-        sym_blank: str = "<blank>",
-        extract_feats_in_collect_stats: bool = True,
-        predictor = None,
-        predictor_weight: float = 0.0,
-        predictor_bias: int = 0,
-        sampling_ratio: float = 0.2,
+            self,
+            vocab_size: int,
+            token_list: Union[Tuple[str, ...], List[str]],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            preencoder: Optional[AbsPreEncoder],
+            encoder: AbsEncoder,
+            postencoder: Optional[AbsPostEncoder],
+            decoder: AbsDecoder,
+            ctc: CTC,
+            ctc_weight: float = 0.5,
+            interctc_weight: float = 0.0,
+            ignore_id: int = -1,
+            blank_id: int = 0,
+            sos: int = 1,
+            eos: int = 2,
+            lsm_weight: float = 0.0,
+            length_normalized_loss: bool = False,
+            report_cer: bool = True,
+            report_wer: bool = True,
+            sym_space: str = "<space>",
+            sym_blank: str = "<blank>",
+            extract_feats_in_collect_stats: bool = True,
+            predictor=None,
+            predictor_weight: float = 0.0,
+            predictor_bias: int = 0,
+            sampling_ratio: float = 0.2,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
         super().__init__(
-        vocab_size=vocab_size,
-        token_list=token_list,
-        frontend=frontend,
-        specaug=specaug,
-        normalize=normalize,
-        preencoder=preencoder,
-        encoder=encoder,
-        postencoder=postencoder,
-        decoder=decoder,
-        ctc=ctc,
-        ctc_weight=ctc_weight,
-        interctc_weight=interctc_weight,
-        ignore_id=ignore_id,
-        blank_id=blank_id,
-        sos=sos,
-        eos=eos,
-        lsm_weight=lsm_weight,
-        length_normalized_loss=length_normalized_loss,
-        report_cer=report_cer,
-        report_wer=report_wer,
-        sym_space=sym_space,
-        sym_blank=sym_blank,
-        extract_feats_in_collect_stats=extract_feats_in_collect_stats,
-        predictor=predictor,
-        predictor_weight=predictor_weight,
-        predictor_bias=predictor_bias,
-        sampling_ratio=sampling_ratio,
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            postencoder=postencoder,
+            decoder=decoder,
+            ctc=ctc,
+            ctc_weight=ctc_weight,
+            interctc_weight=interctc_weight,
+            ignore_id=ignore_id,
+            blank_id=blank_id,
+            sos=sos,
+            eos=eos,
+            lsm_weight=lsm_weight,
+            length_normalized_loss=length_normalized_loss,
+            report_cer=report_cer,
+            report_wer=report_wer,
+            sym_space=sym_space,
+            sym_blank=sym_blank,
+            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+            predictor=predictor,
+            predictor_weight=predictor_weight,
+            predictor_bias=predictor_bias,
+            sampling_ratio=sampling_ratio,
         )
         assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
 
@@ -888,21 +1084,77 @@
         loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
 
         return loss_pre2
-    
+
+    def _calc_att_loss(
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
+    ):
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        if self.predictor_bias == 1:
+            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+            ys_pad_lens = ys_pad_lens + self.predictor_bias
+        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
+                                                                                     encoder_out_mask,
+                                                                                     ignore_id=self.ignore_id)
+
+        # 0. sampler
+        decoder_out_1st = None
+        if self.sampling_ratio > 0.0:
+            if self.step_cur < 2:
+                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+                                                           pre_acoustic_embeds)
+        else:
+            if self.step_cur < 2:
+                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds = pre_acoustic_embeds
+
+        # 1. Forward decoder
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+        )
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+        if decoder_out_1st is None:
+            decoder_out_1st = decoder_out
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_pad)
+        acc_att = th_accuracy(
+            decoder_out_1st.view(-1, self.vocab_size),
+            ys_pad,
+            ignore_label=self.ignore_id,
+        )
+        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out_1st.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att, loss_pre
+
     def calc_predictor(self, encoder_out, encoder_out_lens):
 
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+                                                                                                          None,
+                                                                                                          encoder_out_mask,
+                                                                                                          ignore_id=self.ignore_id)
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-    
+
     def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
         ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
-                                                                                               encoder_out_mask,
-                                                                                               token_num)
+                                                                                            encoder_out_mask,
+                                                                                            token_num)
         return ds_alphas, ds_cif_peak, us_alphas, us_peaks
 
     def forward(
@@ -913,7 +1165,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -996,7 +1247,8 @@
         elif self.ctc_weight == 1.0:
             loss = loss_ctc
         else:
-            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+            loss = self.ctc_weight * loss_ctc + (
+                        1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
 
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1022,11 +1274,11 @@
             self,
             vocab_size: int,
             token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[torch.nn.Module],
-            specaug: Optional[torch.nn.Module],
-            normalize: Optional[torch.nn.Module],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
             preencoder: Optional[AbsPreEncoder],
-            encoder: torch.nn.Module,
+            encoder: AbsEncoder,
             postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
@@ -1120,7 +1372,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -1504,4 +1755,4 @@
                     "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                   var_dict_tf[name_tf].shape))
 
-        return var_dict_torch_update
+        return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index b4a3fa2..da7c674 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -15,8 +15,8 @@
 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
-from funasr.models.base_model import FunASRModel
 from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     pass
@@ -91,7 +91,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index dc7135f..9c3fb92 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -14,9 +14,15 @@
 from torch.nn import functional as F
 from typeguard import check_argument_types
 
+from funasr.modules.nets_utils import to_device
 from funasr.modules.nets_utils import make_pad_mask
-from funasr.models.base_model import FunASRModel
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
 from funasr.utils.misc import int2vec
 
@@ -30,16 +36,20 @@
 
 
 class DiarSondModel(FunASRModel):
-    """Speaker overlap-aware neural diarization model
-    reference: https://arxiv.org/abs/2211.10243
+    """
+    Author: Speech Lab, Alibaba Group, China
+    SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+    https://arxiv.org/abs/2211.10243
+    TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+    https://arxiv.org/abs/2303.05397
     """
 
     def __init__(
         self,
         vocab_size: int,
-        frontend: Optional[torch.nn.Module],
-        specaug: Optional[torch.nn.Module],
-        normalize: Optional[torch.nn.Module],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
         encoder: torch.nn.Module,
         speaker_encoder: Optional[torch.nn.Module],
         ci_scorer: torch.nn.Module,
@@ -105,7 +115,6 @@
         binary_labels_lengths: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
-
         Args:
             speech: (Batch, samples) or (Batch, frames, input_size)
             speech_lengths: (Batch,) default None for chunk interator,
@@ -342,7 +351,7 @@
         cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
         cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
 
-        if isinstance(self.ci_scorer, torch.nn.Module):
+        if isinstance(self.ci_scorer, AbsEncoder):
             ci_simi = self.ci_scorer(ge_in, ge_len)[0]
             ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
         else:
@@ -381,7 +390,6 @@
         self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch,)
@@ -481,4 +489,4 @@
             speaker_miss,
             speaker_falarm,
             speaker_error,
-        )
+        )
\ No newline at end of file
diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py
index 582c25d..bd82c7c 100644
--- a/funasr/models/e2e_sv.py
+++ b/funasr/models/e2e_sv.py
@@ -1,3 +1,8 @@
+
+"""
+Author: Speech Lab, Alibaba Group, China
+"""
+
 import logging
 from contextlib import contextmanager
 from distutils.version import LooseVersion
@@ -10,11 +15,22 @@
 import torch
 from typeguard import check_argument_types
 
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+    LabelSmoothingLoss,  # noqa: H301
+)
+from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.base_model import FunASRModel
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.e2e_asr_common import ErrorCalculator
+from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -32,11 +48,11 @@
             self,
             vocab_size: int,
             token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[torch.nn.Module],
-            specaug: Optional[torch.nn.Module],
-            normalize: Optional[torch.nn.Module],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
             preencoder: Optional[AbsPreEncoder],
-            encoder: torch.nn.Module,
+            encoder: AbsEncoder,
             postencoder: Optional[AbsPostEncoder],
             pooling_layer: torch.nn.Module,
             decoder: AbsDecoder,
@@ -65,7 +81,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -206,7 +221,6 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -256,4 +270,4 @@
         else:
             # No frontend and no feature extract
             feats, feats_lengths = speech, speech_lengths
-        return feats, feats_lengths
+        return feats, feats_lengths
\ No newline at end of file
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index c5dc63c..39419c8 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -2,19 +2,23 @@
 from contextlib import contextmanager
 from distutils.version import LooseVersion
 from typing import Dict
+from typing import List
 from typing import Optional
 from typing import Tuple
+from typing import Union
 
 import torch
+import numpy as np
 from typeguard import check_argument_types
 
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.predictor.cif import mae_loss
-from funasr.models.base_model import FunASRModel
 from funasr.modules.add_sos_eos import add_sos_eos
 from funasr.modules.nets_utils import make_pad_mask, pad_list
 from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
 from funasr.models.predictor.cif import CifPredictorV3
-
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -25,15 +29,15 @@
         yield
 
 
-class TimestampPredictor(FunASRModel):
+class TimestampPredictor(AbsESPnetModel):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     """
 
     def __init__(
             self,
-            frontend: Optional[torch.nn.Module],
-            encoder: torch.nn.Module,
+            frontend: Optional[AbsFrontend],
+            encoder: AbsEncoder,
             predictor: CifPredictorV3,
             predictor_bias: int = 0,
             token_list=None,
@@ -51,7 +55,7 @@
         self.predictor_bias = predictor_bias
         self.criterion_pre = mae_loss()
         self.token_list = token_list
-    
+
     def forward(
             self,
             speech: torch.Tensor,
@@ -60,7 +64,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -108,7 +111,6 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -123,7 +125,7 @@
         encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
 
         return encoder_out, encoder_out_lens
-    
+
     def _extract_feats(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -146,8 +148,8 @@
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
         ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
-                                                                                               encoder_out_mask,
-                                                                                               token_num)
+                                                                                            encoder_out_mask,
+                                                                                            token_num)
         return ds_alphas, ds_cif_peak, us_alphas, us_peaks
 
     def collect_feats(
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index 0c53389..d08ea37 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -17,10 +17,13 @@
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.models.base_model import FunASRModel
 from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
@@ -37,18 +40,18 @@
 
 class UniASR(FunASRModel):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     """
 
     def __init__(
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[torch.nn.Module],
-        specaug: Optional[torch.nn.Module],
-        normalize: Optional[torch.nn.Module],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
         preencoder: Optional[AbsPreEncoder],
-        encoder: torch.nn.Module,
+        encoder: AbsEncoder,
         postencoder: Optional[AbsPostEncoder],
         decoder: AbsDecoder,
         ctc: CTC,
@@ -176,7 +179,6 @@
         decoding_ind: int = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                         speech: (Batch, Length, ...)
                         speech_lengths: (Batch, )
@@ -466,7 +468,6 @@
         self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
                         speech: (Batch, Length, ...)
                         speech_lengths: (Batch, )
@@ -530,7 +531,6 @@
         ind: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
                         speech: (Batch, Length, ...)
                         speech_lengths: (Batch, )
@@ -624,9 +624,7 @@
         ys_pad_lens: torch.Tensor,
     ) -> torch.Tensor:
         """Compute negative log likelihood(nll) from transformer-decoder
-
         Normally, this function is called in batchify_nll.
-
         Args:
                         encoder_out: (Batch, Length, Dim)
                         encoder_out_lens: (Batch,)
@@ -663,7 +661,6 @@
         batch_size: int = 100,
     ):
         """Compute negative log likelihood(nll) from transformer-decoder
-
         To avoid OOM, this fuction seperate the input into batches.
         Then call nll for each batch and combine and return results.
         Args:
@@ -1069,4 +1066,3 @@
             ys_hat = self.ctc2.argmax(encoder_out).data
             cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
         return loss_ctc, cer_ctc
-
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index ff37429..e477750 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -35,6 +35,12 @@
 
 
 class VADXOptions:
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
     def __init__(
             self,
             sample_rate: int = 16000,
@@ -99,6 +105,12 @@
 
 
 class E2EVadSpeechBufWithDoa(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
     def __init__(self):
         self.start_ms = 0
         self.end_ms = 0
@@ -117,6 +129,12 @@
 
 
 class E2EVadFrameProb(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
     def __init__(self):
         self.noise_prob = 0.0
         self.speech_prob = 0.0
@@ -126,6 +144,12 @@
 
 
 class WindowDetector(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
     def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                  speech_to_sil_time: int, frame_size_ms: int):
         self.window_size_ms = window_size_ms
@@ -192,6 +216,12 @@
 
 
 class E2EVadModel(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
     def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
         super(E2EVadModel, self).__init__()
         self.vad_opts = VADXOptions(**vad_post_args)
@@ -286,7 +316,7 @@
                                 0.000001))
 
     def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
-        scores = self.encoder(feats, in_cache)  # return B * T * D
+        scores = self.encoder(feats, in_cache).to('cpu')  # return B * T * D
         assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
         self.vad_opts.nn_eval_block_size = scores.shape[1]
         self.frm_cnt += scores.shape[1]  # count total frames
@@ -444,7 +474,7 @@
                         - 1)) / self.vad_opts.noise_frame_num_used_for_snr
 
         return frame_state
-     
+
     def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                 is_final: bool = False
                 ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -460,8 +490,9 @@
             segment_batch = []
             if len(self.output_data_buf) > 0:
                 for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
-                        i].contain_seg_end_point:
+                    if not is_final and (
+                            not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+                        i].contain_seg_end_point):
                         continue
                     segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
                     segment_batch.append(segment)
@@ -474,11 +505,11 @@
         return segments, in_cache
 
     def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
-                is_final: bool = False, max_end_sil: int = 800
-                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+                       is_final: bool = False, max_end_sil: int = 800
+                       ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
         self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
         self.waveform = waveform  # compute decibel for each frame
-        
+
         self.ComputeScores(feats, in_cache)
         self.ComputeDecibel()
         if not is_final:

--
Gitblit v1.9.1