From 256035b6c1fa6115b6f33972ed243eb43f3e4299 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期五, 14 四月 2023 11:38:00 +0800
Subject: [PATCH] rnnt reorg

---
 funasr/modules/embedding.py                |   77 +++
 funasr/models/e2e_transducer_unified.py    |    2 
 /dev/null                                  |  171 -------
 funasr/models/encoder/conformer_encoder.py |  640 ++++++++++++++++++++++++++
 funasr/modules/attention.py                |  220 +++++++++
 funasr/modules/repeat.py                   |   92 +++
 funasr/modules/subsampling.py              |  202 ++++++++
 funasr/tasks/asr_transducer.py             |    6 
 funasr/models/e2e_transducer.py            |    2 
 funasr/modules/normalization.py            |    0 
 10 files changed, 1,233 insertions(+), 179 deletions(-)

diff --git a/funasr/models/e2e_transducer.py b/funasr/models/e2e_transducer.py
index b669c9d..8630aec 100644
--- a/funasr/models/e2e_transducer.py
+++ b/funasr/models/e2e_transducer.py
@@ -12,7 +12,7 @@
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
 from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
 from funasr.models.joint_network import JointNetwork
 from funasr.modules.nets_utils import get_transducer_task_io
 from funasr.layers.abs_normalize import AbsNormalize
diff --git a/funasr/models/e2e_transducer_unified.py b/funasr/models/e2e_transducer_unified.py
index 6003542..124bc09 100644
--- a/funasr/models/e2e_transducer_unified.py
+++ b/funasr/models/e2e_transducer_unified.py
@@ -11,7 +11,7 @@
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
 from funasr.models.joint_network import JointNetwork
 from funasr.modules.nets_utils import get_transducer_task_io
 from funasr.layers.abs_normalize import AbsNormalize
diff --git a/funasr/models/encoder/chunk_encoder.py b/funasr/models/encoder/chunk_encoder.py
deleted file mode 100644
index c6fc292..0000000
--- a/funasr/models/encoder/chunk_encoder.py
+++ /dev/null
@@ -1,292 +0,0 @@
-from typing import Any, Dict, List, Tuple
-
-import torch
-from typeguard import check_argument_types
-
-from funasr.models.encoder.chunk_encoder_utils.building import (
-    build_body_blocks,
-    build_input_block,
-    build_main_parameters,
-    build_positional_encoding,
-)
-from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
-from funasr.modules.nets_utils import (
-    TooShortUttError,
-    check_short_utt,
-    make_chunk_mask,
-    make_source_mask,
-)
-
-class ChunkEncoder(torch.nn.Module):
-    """Encoder module definition.
-
-    Args:
-        input_size: Input size.
-        body_conf: Encoder body configuration.
-        input_conf: Encoder input configuration.
-        main_conf: Encoder main configuration.
-
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        body_conf: List[Dict[str, Any]],
-        input_conf: Dict[str, Any] = {},
-        main_conf: Dict[str, Any] = {},
-    ) -> None:
-        """Construct an Encoder object."""
-        super().__init__()
-
-        assert check_argument_types()
-
-        embed_size, output_size = validate_architecture(
-            input_conf, body_conf, input_size
-        )
-        main_params = build_main_parameters(**main_conf)
-
-        self.embed = build_input_block(input_size, input_conf)
-        self.pos_enc = build_positional_encoding(embed_size, main_params)
-        self.encoders = build_body_blocks(body_conf, main_params, output_size)
-
-        self.output_size = output_size
-
-        self.dynamic_chunk_training = main_params["dynamic_chunk_training"]
-        self.short_chunk_threshold = main_params["short_chunk_threshold"]
-        self.short_chunk_size = main_params["short_chunk_size"]
-        self.left_chunk_size = main_params["left_chunk_size"]
-
-        self.unified_model_training = main_params["unified_model_training"]
-        self.default_chunk_size = main_params["default_chunk_size"]
-        self.jitter_range = main_params["jitter_range"]
-
-        self.time_reduction_factor = main_params["time_reduction_factor"]
-    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
-        """Return the corresponding number of sample for a given chunk size, in frames.
-
-        Where size is the number of features frames after applying subsampling.
-
-        Args:
-            size: Number of frames after subsampling.
-            hop_length: Frontend's hop length
-
-        Returns:
-            : Number of raw samples
-
-        """
-        return self.embed.get_size_before_subsampling(size) * hop_length
-
-    def get_encoder_input_size(self, size: int) -> int:
-        """Return the corresponding number of sample for a given chunk size, in frames.
-
-        Where size is the number of features frames after applying subsampling.
-
-        Args:
-            size: Number of frames after subsampling.
-
-        Returns:
-            : Number of raw samples
-
-        """
-        return self.embed.get_size_before_subsampling(size)
-
-
-    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
-        """Initialize/Reset encoder streaming cache.
-
-        Args:
-            left_context: Number of frames in left context.
-            device: Device ID.
-
-        """
-        return self.encoders.reset_streaming_cache(left_context, device)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        x_len: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encode input sequences.
-
-        Args:
-            x: Encoder input features. (B, T_in, F)
-            x_len: Encoder input features lengths. (B,)
-
-        Returns:
-           x: Encoder outputs. (B, T_out, D_enc)
-           x_len: Encoder outputs lenghts. (B,)
-
-        """
-        short_status, limit_size = check_short_utt(
-            self.embed.subsampling_factor, x.size(1)
-        )
-
-        if short_status:
-            raise TooShortUttError(
-                f"has {x.size(1)} frames and is too short for subsampling "
-                + f"(it needs more than {limit_size} frames), return empty results",
-                x.size(1),
-                limit_size,
-            )
-
-        mask = make_source_mask(x_len)
-
-        if self.unified_model_training:
-            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
-            x, mask = self.embed(x, mask, chunk_size)
-            pos_enc = self.pos_enc(x)
-            chunk_mask = make_chunk_mask(
-                x.size(1),
-                chunk_size,
-                left_chunk_size=self.left_chunk_size,
-                device=x.device,
-            )
-            x_utt = self.encoders(
-                x,
-                pos_enc,
-                mask,
-                chunk_mask=None,
-            )
-            x_chunk = self.encoders(
-                x,
-                pos_enc,
-                mask,
-                chunk_mask=chunk_mask,
-            )
-
-            olens = mask.eq(0).sum(1)
-            if self.time_reduction_factor > 1:
-                x_utt = x_utt[:,::self.time_reduction_factor,:]
-                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
-                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
-            return x_utt, x_chunk, olens
-
-        elif self.dynamic_chunk_training:
-            max_len = x.size(1)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
-
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
-            else:
-                chunk_size = (chunk_size % self.short_chunk_size) + 1
-
-            x, mask = self.embed(x, mask, chunk_size)
-            pos_enc = self.pos_enc(x)
-
-            chunk_mask = make_chunk_mask(
-                x.size(1),
-                chunk_size,
-                left_chunk_size=self.left_chunk_size,
-                device=x.device,
-            )
-        else:
-            x, mask = self.embed(x, mask, None)
-            pos_enc = self.pos_enc(x)
-            chunk_mask = None
-        x = self.encoders(
-            x,
-            pos_enc,
-            mask,
-            chunk_mask=chunk_mask,
-        )
-
-        olens = mask.eq(0).sum(1)
-        if self.time_reduction_factor > 1:
-            x = x[:,::self.time_reduction_factor,:]
-            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
-        return x, olens
-
-    def simu_chunk_forward(
-        self,
-        x: torch.Tensor,
-        x_len: torch.Tensor,
-        chunk_size: int = 16,
-        left_context: int = 32,
-        right_context: int = 0,
-    ) -> torch.Tensor:
-        short_status, limit_size = check_short_utt(
-            self.embed.subsampling_factor, x.size(1)
-        )
-
-        if short_status:
-            raise TooShortUttError(
-                f"has {x.size(1)} frames and is too short for subsampling "
-                + f"(it needs more than {limit_size} frames), return empty results",
-                x.size(1),
-                limit_size,
-            )
-
-        mask = make_source_mask(x_len)
-
-        x, mask = self.embed(x, mask, chunk_size)
-        pos_enc = self.pos_enc(x)
-        chunk_mask = make_chunk_mask(
-            x.size(1),
-            chunk_size,
-            left_chunk_size=self.left_chunk_size,
-            device=x.device,
-        )
-
-        x = self.encoders(
-            x,
-            pos_enc,
-            mask,
-            chunk_mask=chunk_mask,
-        )
-        olens = mask.eq(0).sum(1)
-        if self.time_reduction_factor > 1:
-            x = x[:,::self.time_reduction_factor,:]
-
-        return x
-
-    def chunk_forward(
-        self,
-        x: torch.Tensor,
-        x_len: torch.Tensor,
-        processed_frames: torch.tensor,
-        chunk_size: int = 16,
-        left_context: int = 32,
-        right_context: int = 0,
-    ) -> torch.Tensor:
-        """Encode input sequences as chunks.
-
-        Args:
-            x: Encoder input features. (1, T_in, F)
-            x_len: Encoder input features lengths. (1,)
-            processed_frames: Number of frames already seen.
-            left_context: Number of frames in left context.
-            right_context: Number of frames in right context.
-
-        Returns:
-           x: Encoder outputs. (B, T_out, D_enc)
-
-        """
-        mask = make_source_mask(x_len)
-        x, mask = self.embed(x, mask, None)
-
-        if left_context > 0:
-            processed_mask = (
-                torch.arange(left_context, device=x.device)
-                .view(1, left_context)
-                .flip(1)
-            )
-            processed_mask = processed_mask >= processed_frames
-            mask = torch.cat([processed_mask, mask], dim=1)
-        pos_enc = self.pos_enc(x, left_context=left_context)
-        x = self.encoders.chunk_forward(
-            x,
-            pos_enc,
-            mask,
-            chunk_size=chunk_size,
-            left_context=left_context,
-            right_context=right_context,
-        )
-
-        if right_context > 0:
-            x = x[:, 0:-right_context, :]
-
-        if self.time_reduction_factor > 1:
-            x = x[:,::self.time_reduction_factor,:]
-        return x
diff --git a/funasr/models/encoder/chunk_encoder_blocks/__init__.py b/funasr/models/encoder/chunk_encoder_blocks/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/__init__.py
+++ /dev/null
diff --git a/funasr/models/encoder/chunk_encoder_blocks/branchformer.py b/funasr/models/encoder/chunk_encoder_blocks/branchformer.py
deleted file mode 100644
index ba0b25d..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/branchformer.py
+++ /dev/null
@@ -1,178 +0,0 @@
-"""Branchformer block for Transducer encoder."""
-
-from typing import Dict, Optional, Tuple
-
-import torch
-
-
-class Branchformer(torch.nn.Module):
-    """Branchformer module definition.
-
-    Reference: https://arxiv.org/pdf/2207.02971.pdf
-
-    Args:
-        block_size: Input/output size.
-        linear_size: Linear layers' hidden size.
-        self_att: Self-attention module instance.
-        conv_mod: Convolution module instance.
-        norm_class: Normalization class.
-        norm_args: Normalization module arguments.
-        dropout_rate: Dropout rate.
-
-    """
-
-    def __init__(
-        self,
-        block_size: int,
-        linear_size: int,
-        self_att: torch.nn.Module,
-        conv_mod: torch.nn.Module,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
-        norm_args: Dict = {},
-        dropout_rate: float = 0.0,
-    ) -> None:
-        """Construct a Branchformer object."""
-        super().__init__()
-
-        self.self_att = self_att
-        self.conv_mod = conv_mod
-
-        self.channel_proj1 = torch.nn.Sequential(
-            torch.nn.Linear(block_size, linear_size), torch.nn.GELU()
-        )
-        self.channel_proj2 = torch.nn.Linear(linear_size // 2, block_size)
-
-        self.merge_proj = torch.nn.Linear(block_size + block_size, block_size)
-
-        self.norm_self_att = norm_class(block_size, **norm_args)
-        self.norm_mlp = norm_class(block_size, **norm_args)
-        self.norm_final = norm_class(block_size, **norm_args)
-
-        self.dropout = torch.nn.Dropout(dropout_rate)
-
-        self.block_size = block_size
-        self.linear_size = linear_size
-        self.cache = None
-
-    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
-        """Initialize/Reset self-attention and convolution modules cache for streaming.
-
-        Args:
-            left_context: Number of left frames during chunk-by-chunk inference.
-            device: Device to use for cache tensor.
-
-        """
-        self.cache = [
-            torch.zeros(
-                (1, left_context, self.block_size),
-                device=device,
-            ),
-            torch.zeros(
-                (
-                    1,
-                    self.linear_size // 2,
-                    self.conv_mod.kernel_size - 1,
-                ),
-                device=device,
-            ),
-        ]
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Encode input sequences.
-
-        Args:
-            x: Branchformer input sequences. (B, T, D_block)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-            mask: Source mask. (B, T)
-            chunk_mask: Chunk mask. (T_2, T_2)
-
-        Returns:
-            x: Branchformer output sequences. (B, T, D_block)
-            mask: Source mask. (B, T)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-
-        """
-        x1 = x
-        x2 = x
-
-        x1 = self.norm_self_att(x1)
-
-        x1 = self.dropout(
-            self.self_att(x1, x1, x1, pos_enc, mask=mask, chunk_mask=chunk_mask)
-        )
-
-        x2 = self.norm_mlp(x2)
-
-        x2 = self.channel_proj1(x2)
-        x2, _ = self.conv_mod(x2)
-        x2 = self.channel_proj2(x2)
-
-        x2 = self.dropout(x2)
-
-        x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1)))
-
-        x = self.norm_final(x)
-
-        return x, mask, pos_enc
-
-    def chunk_forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        left_context: int = 0,
-        right_context: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encode chunk of input sequence.
-
-        Args:
-            x: Branchformer input sequences. (B, T, D_block)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-            mask: Source mask. (B, T_2)
-            left_context: Number of frames in left context.
-            right_context: Number of frames in right context.
-
-        Returns:
-            x: Branchformer output sequences. (B, T, D_block)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-
-        """
-        x1 = x
-        x2 = x
-
-        x1 = self.norm_self_att(x1)
-
-        if left_context > 0:
-            key = torch.cat([self.cache[0], x1], dim=1)
-        else:
-            key = x1
-        val = key
-
-        if right_context > 0:
-            att_cache = key[:, -(left_context + right_context) : -right_context, :]
-        else:
-            att_cache = key[:, -left_context:, :]
-
-        x1 = self.self_att(x1, key, val, pos_enc, mask=mask, left_context=left_context)
-
-        x2 = self.norm_mlp(x2)
-        x2 = self.channel_proj1(x2)
-
-        x2, conv_cache = self.conv_mod(
-            x2, cache=self.cache[1], right_context=right_context
-        )
-
-        x2 = self.channel_proj2(x2)
-
-        x = x + self.merge_proj(torch.cat([x1, x2], dim=-1))
-
-        x = self.norm_final(x)
-        self.cache = [att_cache, conv_cache]
-
-        return x, pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_blocks/conformer.py b/funasr/models/encoder/chunk_encoder_blocks/conformer.py
deleted file mode 100644
index 0b9bbbf..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/conformer.py
+++ /dev/null
@@ -1,198 +0,0 @@
-"""Conformer block for Transducer encoder."""
-
-from typing import Dict, Optional, Tuple
-
-import torch
-
-
-class Conformer(torch.nn.Module):
-    """Conformer module definition.
-
-    Args:
-        block_size: Input/output size.
-        self_att: Self-attention module instance.
-        feed_forward: Feed-forward module instance.
-        feed_forward_macaron: Feed-forward module instance for macaron network.
-        conv_mod: Convolution module instance.
-        norm_class: Normalization module class.
-        norm_args: Normalization module arguments.
-        dropout_rate: Dropout rate.
-
-    """
-
-    def __init__(
-        self,
-        block_size: int,
-        self_att: torch.nn.Module,
-        feed_forward: torch.nn.Module,
-        feed_forward_macaron: torch.nn.Module,
-        conv_mod: torch.nn.Module,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
-        norm_args: Dict = {},
-        dropout_rate: float = 0.0,
-    ) -> None:
-        """Construct a Conformer object."""
-        super().__init__()
-
-        self.self_att = self_att
-
-        self.feed_forward = feed_forward
-        self.feed_forward_macaron = feed_forward_macaron
-        self.feed_forward_scale = 0.5
-
-        self.conv_mod = conv_mod
-
-        self.norm_feed_forward = norm_class(block_size, **norm_args)
-        self.norm_self_att = norm_class(block_size, **norm_args)
-
-        self.norm_macaron = norm_class(block_size, **norm_args)
-        self.norm_conv = norm_class(block_size, **norm_args)
-        self.norm_final = norm_class(block_size, **norm_args)
-
-        self.dropout = torch.nn.Dropout(dropout_rate)
-
-        self.block_size = block_size
-        self.cache = None
-
-    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
-        """Initialize/Reset self-attention and convolution modules cache for streaming.
-
-        Args:
-            left_context: Number of left frames during chunk-by-chunk inference.
-            device: Device to use for cache tensor.
-
-        """
-        self.cache = [
-            torch.zeros(
-                (1, left_context, self.block_size),
-                device=device,
-            ),
-            torch.zeros(
-                (
-                    1,
-                    self.block_size,
-                    self.conv_mod.kernel_size - 1,
-                ),
-                device=device,
-            ),
-        ]
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Encode input sequences.
-
-        Args:
-            x: Conformer input sequences. (B, T, D_block)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-            mask: Source mask. (B, T)
-            chunk_mask: Chunk mask. (T_2, T_2)
-
-        Returns:
-            x: Conformer output sequences. (B, T, D_block)
-            mask: Source mask. (B, T)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-
-        """
-        residual = x
-
-        x = self.norm_macaron(x)
-        x = residual + self.feed_forward_scale * self.dropout(
-            self.feed_forward_macaron(x)
-        )
-
-        residual = x
-        x = self.norm_self_att(x)
-        x_q = x
-        x = residual + self.dropout(
-            self.self_att(
-                x_q,
-                x,
-                x,
-                pos_enc,
-                mask,
-                chunk_mask=chunk_mask,
-            )
-        )
-
-        residual = x
-
-        x = self.norm_conv(x)
-        x, _ = self.conv_mod(x)
-        x = residual + self.dropout(x)
-        residual = x
-
-        x = self.norm_feed_forward(x)
-        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
-
-        x = self.norm_final(x)
-        return x, mask, pos_enc
-
-    def chunk_forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_size: int = 16,
-        left_context: int = 0,
-        right_context: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encode chunk of input sequence.
-
-        Args:
-            x: Conformer input sequences. (B, T, D_block)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-            mask: Source mask. (B, T_2)
-            left_context: Number of frames in left context.
-            right_context: Number of frames in right context.
-
-        Returns:
-            x: Conformer output sequences. (B, T, D_block)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-
-        """
-        residual = x
-
-        x = self.norm_macaron(x)
-        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
-
-        residual = x
-        x = self.norm_self_att(x)
-        if left_context > 0:
-            key = torch.cat([self.cache[0], x], dim=1)
-        else:
-            key = x
-        val = key
-
-        if right_context > 0:
-            att_cache = key[:, -(left_context + right_context) : -right_context, :]
-        else:
-            att_cache = key[:, -left_context:, :]
-        x = residual + self.self_att(
-            x,
-            key,
-            val,
-            pos_enc,
-            mask,
-            left_context=left_context,
-        )
-
-        residual = x
-        x = self.norm_conv(x)
-        x, conv_cache = self.conv_mod(
-            x, cache=self.cache[1], right_context=right_context
-        )
-        x = residual + x
-        residual = x
-        
-        x = self.norm_feed_forward(x)
-        x = residual + self.feed_forward_scale * self.feed_forward(x)
-
-        x = self.norm_final(x)
-        self.cache = [att_cache, conv_cache]
-       
-        return x, pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_blocks/conv1d.py b/funasr/models/encoder/chunk_encoder_blocks/conv1d.py
deleted file mode 100644
index f79cc37..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/conv1d.py
+++ /dev/null
@@ -1,221 +0,0 @@
-"""Conv1d block for Transducer encoder."""
-
-from typing import Optional, Tuple, Union
-
-import torch
-
-
-class Conv1d(torch.nn.Module):
-    """Conv1d module definition.
-
-    Args:
-        input_size: Input dimension.
-        output_size: Output dimension.
-        kernel_size: Size of the convolving kernel.
-        stride: Stride of the convolution.
-        dilation: Spacing between the kernel points.
-        groups: Number of blocked connections from input channels to output channels.
-        bias: Whether to add a learnable bias to the output.
-        batch_norm: Whether to use batch normalization after convolution.
-        relu: Whether to use a ReLU activation after convolution.
-        causal: Whether to use causal convolution (set to True if streaming).
-        dropout_rate: Dropout rate.
-
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int,
-        kernel_size: Union[int, Tuple],
-        stride: Union[int, Tuple] = 1,
-        dilation: Union[int, Tuple] = 1,
-        groups: Union[int, Tuple] = 1,
-        bias: bool = True,
-        batch_norm: bool = False,
-        relu: bool = True,
-        causal: bool = False,
-        dropout_rate: float = 0.0,
-    ) -> None:
-        """Construct a Conv1d object."""
-        super().__init__()
-
-        if causal:
-            self.lorder = kernel_size - 1
-            stride = 1
-        else:
-            self.lorder = 0
-            stride = stride
-
-        self.conv = torch.nn.Conv1d(
-            input_size,
-            output_size,
-            kernel_size,
-            stride=stride,
-            dilation=dilation,
-            groups=groups,
-            bias=bias,
-        )
-
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
-
-        if relu:
-            self.relu_func = torch.nn.ReLU()
-
-        if batch_norm:
-            self.bn = torch.nn.BatchNorm1d(output_size)
-
-        self.out_pos = torch.nn.Linear(input_size, output_size)
-
-        self.input_size = input_size
-        self.output_size = output_size
-
-        self.relu = relu
-        self.batch_norm = batch_norm
-        self.causal = causal
-
-        self.kernel_size = kernel_size
-        self.padding = dilation * (kernel_size - 1)
-        self.stride = stride
-
-        self.cache = None
-
-    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
-        """Initialize/Reset Conv1d cache for streaming.
-
-        Args:
-            left_context: Number of left frames during chunk-by-chunk inference.
-            device: Device to use for cache tensor.
-
-        """
-        self.cache = torch.zeros(
-            (1, self.input_size, self.kernel_size - 1), device=device
-        )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        chunk_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Encode input sequences.
-
-        Args:
-            x: Conv1d input sequences. (B, T, D_in)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
-            mask: Source mask. (B, T)
-            chunk_mask: Chunk mask. (T_2, T_2)
-
-        Returns:
-            x: Conv1d output sequences. (B, sub(T), D_out)
-            mask: Source mask. (B, T) or (B, sub(T))
-            pos_enc: Positional embedding sequences.
-                       (B, 2 * (T - 1), D_att) or (B, 2 * (sub(T) - 1), D_out)
-
-        """
-        x = x.transpose(1, 2)
-
-        if self.lorder > 0:
-            x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
-        else:
-            mask = self.create_new_mask(mask)
-            pos_enc = self.create_new_pos_enc(pos_enc)
-
-        x = self.conv(x)
-
-        if self.batch_norm:
-            x = self.bn(x)
-
-        x = self.dropout(x)
-
-        if self.relu:
-            x = self.relu_func(x)
-
-        x = x.transpose(1, 2)
-
-        return x, mask, self.out_pos(pos_enc)
-
-    def chunk_forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        left_context: int = 0,
-        right_context: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encode chunk of input sequence.
-
-        Args:
-            x: Conv1d input sequences. (B, T, D_in)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
-            mask: Source mask. (B, T)
-            left_context: Number of frames in left context.
-            right_context: Number of frames in right context.
-
-        Returns:
-            x: Conv1d output sequences. (B, T, D_out)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_out)
-
-        """
-        x = torch.cat([self.cache, x.transpose(1, 2)], dim=2)
-
-        if right_context > 0:
-            self.cache = x[:, :, -(self.lorder + right_context) : -right_context]
-        else:
-            self.cache = x[:, :, -self.lorder :]
-
-        x = self.conv(x)
-
-        if self.batch_norm:
-            x = self.bn(x)
-
-        x = self.dropout(x)
-
-        if self.relu:
-            x = self.relu_func(x)
-
-        x = x.transpose(1, 2)
-
-        return x, self.out_pos(pos_enc)
-
-    def create_new_mask(self, mask: torch.Tensor) -> torch.Tensor:
-        """Create new mask for output sequences.
-
-        Args:
-            mask: Mask of input sequences. (B, T)
-
-        Returns:
-            mask: Mask of output sequences. (B, sub(T))
-
-        """
-        if self.padding != 0:
-            mask = mask[:, : -self.padding]
-
-        return mask[:, :: self.stride]
-
-    def create_new_pos_enc(self, pos_enc: torch.Tensor) -> torch.Tensor:
-        """Create new positional embedding vector.
-
-        Args:
-            pos_enc: Input sequences positional embedding.
-                     (B, 2 * (T - 1), D_in)
-
-        Returns:
-            pos_enc: Output sequences positional embedding.
-                     (B, 2 * (sub(T) - 1), D_in)
-
-        """
-        pos_enc_positive = pos_enc[:, : pos_enc.size(1) // 2 + 1, :]
-        pos_enc_negative = pos_enc[:, pos_enc.size(1) // 2 :, :]
-
-        if self.padding != 0:
-            pos_enc_positive = pos_enc_positive[:, : -self.padding, :]
-            pos_enc_negative = pos_enc_negative[:, : -self.padding, :]
-
-        pos_enc_positive = pos_enc_positive[:, :: self.stride, :]
-        pos_enc_negative = pos_enc_negative[:, :: self.stride, :]
-
-        pos_enc = torch.cat([pos_enc_positive, pos_enc_negative[:, 1:, :]], dim=1)
-
-        return pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_blocks/conv_input.py b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py
deleted file mode 100644
index b9bd2fd..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/conv_input.py
+++ /dev/null
@@ -1,222 +0,0 @@
-"""ConvInput block for Transducer encoder."""
-
-from typing import Optional, Tuple, Union
-
-import torch
-import math
-
-from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
-
-
-class ConvInput(torch.nn.Module):
-    """ConvInput module definition.
-
-    Args:
-        input_size: Input size.
-        conv_size: Convolution size.
-        subsampling_factor: Subsampling factor.
-        vgg_like: Whether to use a VGG-like network.
-        output_size: Block output dimension.
-
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        conv_size: Union[int, Tuple],
-        subsampling_factor: int = 4,
-        vgg_like: bool = True,
-        output_size: Optional[int] = None,
-    ) -> None:
-        """Construct a ConvInput object."""
-        super().__init__()
-        if vgg_like:
-            if subsampling_factor == 1:
-                conv_size1, conv_size2 = conv_size
-
-                self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
-                    torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
-                    torch.nn.ReLU(),
-                    torch.nn.MaxPool2d((1, 2)),
-                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
-                    torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
-                    torch.nn.ReLU(),
-                    torch.nn.MaxPool2d((1, 2)),
-                )
-
-                output_proj = conv_size2 * ((input_size // 2) // 2)
- 
-                self.subsampling_factor = 1
-
-                self.stride_1 = 1
-
-                self.create_new_mask = self.create_new_vgg_mask
-
-            else:
-                conv_size1, conv_size2 = conv_size
-
-                kernel_1 = int(subsampling_factor / 2)
-
-                self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
-                    torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
-                    torch.nn.ReLU(),
-                    torch.nn.MaxPool2d((kernel_1, 2)),
-                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
-                    torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
-                    torch.nn.ReLU(),
-                    torch.nn.MaxPool2d((2, 2)),
-                )
-
-                output_proj = conv_size2 * ((input_size // 2) // 2)
-
-                self.subsampling_factor = subsampling_factor
-
-                self.create_new_mask = self.create_new_vgg_mask
-                
-                self.stride_1 = kernel_1
-
-        else:
-            if subsampling_factor == 1:
-                self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
-                    torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
-                    torch.nn.ReLU(),
-                )
-
-                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
-
-                self.subsampling_factor = subsampling_factor
-                self.kernel_2 = 3
-                self.stride_2 = 1
-
-                self.create_new_mask = self.create_new_conv2d_mask
-
-            else:
-                kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
-                    subsampling_factor,
-                    input_size,
-                )
-
-                self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size, 3, 2),
-                    torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
-                    torch.nn.ReLU(),
-                )
-
-                output_proj = conv_size * conv_2_output_size
-
-                self.subsampling_factor = subsampling_factor
-                self.kernel_2 = kernel_2
-                self.stride_2 = stride_2
-
-                self.create_new_mask = self.create_new_conv2d_mask
-
-        self.vgg_like = vgg_like
-        self.min_frame_length = 7
-
-        if output_size is not None:
-            self.output = torch.nn.Linear(output_proj, output_size)
-            self.output_size = output_size
-        else:
-            self.output = None
-            self.output_size = output_proj
-
-    def forward(
-        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encode input sequences.
-
-        Args:
-            x: ConvInput input sequences. (B, T, D_feats)
-            mask: Mask of input sequences. (B, 1, T)
-
-        Returns:
-            x: ConvInput output sequences. (B, sub(T), D_out)
-            mask: Mask of output sequences. (B, 1, sub(T))
-
-        """
-        if mask is not None:
-            mask = self.create_new_mask(mask)
-            olens = max(mask.eq(0).sum(1))
-
-        b, t, f = x.size()
-        x = x.unsqueeze(1) # (b. 1. t. f)
-
-        if chunk_size is not None:
-            max_input_length = int(
-                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
-            )
-            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
-            x = list(x)
-            x = torch.stack(x, dim=0)
-            N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
-            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
-
-        x = self.conv(x)
-
-        _, c, _, f = x.size()
-        if chunk_size is not None:
-            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
-        else:
-            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
-
-        if self.output is not None:
-            x = self.output(x)
-
-        return x, mask[:,:olens][:,:x.size(1)]
-
-    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
-        """Create a new mask for VGG output sequences.
-
-        Args:
-            mask: Mask of input sequences. (B, T)
-
-        Returns:
-            mask: Mask of output sequences. (B, sub(T))
-
-        """
-        if self.subsampling_factor > 1:
-            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
-            mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
-
-            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
-            mask = mask[:, :vgg2_t_len][:, ::2]
-        else:
-            mask = mask
-
-        return mask
-
-    def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
-        """Create new conformer mask for Conv2d output sequences.
-
-        Args:
-            mask: Mask of input sequences. (B, T)
-
-        Returns:
-            mask: Mask of output sequences. (B, sub(T))
-
-        """
-        if self.subsampling_factor > 1:
-            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
-        else:
-            return mask
-
-    def get_size_before_subsampling(self, size: int) -> int:
-        """Return the original size before subsampling for a given size.
-
-        Args:
-            size: Number of frames after subsampling.
-
-        Returns:
-            : Number of frames before subsampling.
-
-        """
-        return size * self.subsampling_factor
diff --git a/funasr/models/encoder/chunk_encoder_blocks/linear_input.py b/funasr/models/encoder/chunk_encoder_blocks/linear_input.py
deleted file mode 100644
index 9bb9698..0000000
--- a/funasr/models/encoder/chunk_encoder_blocks/linear_input.py
+++ /dev/null
@@ -1,52 +0,0 @@
-"""LinearInput block for Transducer encoder."""
-
-from typing import Optional, Tuple, Union
-
-import torch
-
-class LinearInput(torch.nn.Module):
-    """ConvInput module definition.
-
-    Args:
-        input_size: Input size.
-        conv_size: Convolution size.
-        subsampling_factor: Subsampling factor.
-        vgg_like: Whether to use a VGG-like network.
-        output_size: Block output dimension.
-
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: Optional[int] = None,
-        subsampling_factor: int = 1,
-    ) -> None:
-        """Construct a ConvInput object."""
-        super().__init__()
-        self.embed = torch.nn.Sequential(
-            torch.nn.Linear(input_size, output_size),
-            torch.nn.LayerNorm(output_size),
-            torch.nn.Dropout(0.1),
-        )
-        self.subsampling_factor = subsampling_factor
-        self.min_frame_length = 1
-
-    def forward(
-        self, x: torch.Tensor, mask: Optional[torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        
-        x = self.embed(x)
-        return x, mask
-
-    def get_size_before_subsampling(self, size: int) -> int:
-        """Return the original size before subsampling for a given size.
-
-        Args:
-            size: Number of frames after subsampling.
-
-        Returns:
-            : Number of frames before subsampling.
-
-        """
-        return size
diff --git a/funasr/models/encoder/chunk_encoder_modules/__init__.py b/funasr/models/encoder/chunk_encoder_modules/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/__init__.py
+++ /dev/null
diff --git a/funasr/models/encoder/chunk_encoder_modules/attention.py b/funasr/models/encoder/chunk_encoder_modules/attention.py
deleted file mode 100644
index 53e7087..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/attention.py
+++ /dev/null
@@ -1,246 +0,0 @@
-"""Multi-Head attention layers with relative positional encoding."""
-
-import math
-from typing import Optional, Tuple
-
-import torch
-
-
-class RelPositionMultiHeadedAttention(torch.nn.Module):
-    """RelPositionMultiHeadedAttention definition.
-
-    Args:
-        num_heads: Number of attention heads.
-        embed_size: Embedding size.
-        dropout_rate: Dropout rate.
-
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        embed_size: int,
-        dropout_rate: float = 0.0,
-        simplified_attention_score: bool = False,
-    ) -> None:
-        """Construct an MultiHeadedAttention object."""
-        super().__init__()
-
-        self.d_k = embed_size // num_heads
-        self.num_heads = num_heads
-
-        assert self.d_k * num_heads == embed_size, (
-            "embed_size (%d) must be divisible by num_heads (%d)",
-            (embed_size, num_heads),
-        )
-
-        self.linear_q = torch.nn.Linear(embed_size, embed_size)
-        self.linear_k = torch.nn.Linear(embed_size, embed_size)
-        self.linear_v = torch.nn.Linear(embed_size, embed_size)
-
-        self.linear_out = torch.nn.Linear(embed_size, embed_size)
-
-        if simplified_attention_score:
-            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
-
-            self.compute_att_score = self.compute_simplified_attention_score
-        else:
-            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
-
-            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
-            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
-            torch.nn.init.xavier_uniform_(self.pos_bias_u)
-            torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
-            self.compute_att_score = self.compute_attention_score
-
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
-        self.attn = None
-
-    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
-        """Compute relative positional encoding.
-
-        Args:
-            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
-            left_context: Number of frames in left context.
-
-        Returns:
-            x: Output sequence. (B, H, T_1, T_2)
-
-        """
-        batch_size, n_heads, time1, n = x.shape
-        time2 = time1 + left_context
-
-        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
-
-        return x.as_strided(
-            (batch_size, n_heads, time1, time2),
-            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
-            storage_offset=(n_stride * (time1 - 1)),
-        )
-
-    def compute_simplified_attention_score(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        pos_enc: torch.Tensor,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Simplified attention score computation.
-
-        Reference: https://github.com/k2-fsa/icefall/pull/458
-
-        Args:
-            query: Transformed query tensor. (B, H, T_1, d_k)
-            key: Transformed key tensor. (B, H, T_2, d_k)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            left_context: Number of frames in left context.
-
-        Returns:
-            : Attention score. (B, H, T_1, T_2)
-
-        """
-        pos_enc = self.linear_pos(pos_enc)
-
-        matrix_ac = torch.matmul(query, key.transpose(2, 3))
-
-        matrix_bd = self.rel_shift(
-            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
-            left_context=left_context,
-        )
-
-        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
-    def compute_attention_score(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        pos_enc: torch.Tensor,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Attention score computation.
-
-        Args:
-            query: Transformed query tensor. (B, H, T_1, d_k)
-            key: Transformed key tensor. (B, H, T_2, d_k)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            left_context: Number of frames in left context.
-
-        Returns:
-            : Attention score. (B, H, T_1, T_2)
-
-        """
-        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
-
-        query = query.transpose(1, 2)
-        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
-        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
-
-        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
-
-        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
-        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
-
-        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
-    def forward_qkv(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Transform query, key and value.
-
-        Args:
-            query: Query tensor. (B, T_1, size)
-            key: Key tensor. (B, T_2, size)
-            v: Value tensor. (B, T_2, size)
-
-        Returns:
-            q: Transformed query tensor. (B, H, T_1, d_k)
-            k: Transformed key tensor. (B, H, T_2, d_k)
-            v: Transformed value tensor. (B, H, T_2, d_k)
-
-        """
-        n_batch = query.size(0)
-
-        q = (
-            self.linear_q(query)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-        k = (
-            self.linear_k(key)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-        v = (
-            self.linear_v(value)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-
-        return q, k, v
-
-    def forward_attention(
-        self,
-        value: torch.Tensor,
-        scores: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Compute attention context vector.
-
-        Args:
-            value: Transformed value. (B, H, T_2, d_k)
-            scores: Attention score. (B, H, T_1, T_2)
-            mask: Source mask. (B, T_2)
-            chunk_mask: Chunk mask. (T_1, T_1)
-
-        Returns:
-           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
-
-        """
-        batch_size = scores.size(0)
-        mask = mask.unsqueeze(1).unsqueeze(2)
-        if chunk_mask is not None:
-            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
-        scores = scores.masked_fill(mask, float("-inf"))
-        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
-
-        attn_output = self.dropout(self.attn)
-        attn_output = torch.matmul(attn_output, value)
-
-        attn_output = self.linear_out(
-            attn_output.transpose(1, 2)
-            .contiguous()
-            .view(batch_size, -1, self.num_heads * self.d_k)
-        )
-
-        return attn_output
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Compute scaled dot product attention with rel. positional encoding.
-
-        Args:
-            query: Query tensor. (B, T_1, size)
-            key: Key tensor. (B, T_2, size)
-            value: Value tensor. (B, T_2, size)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            mask: Source mask. (B, T_2)
-            chunk_mask: Chunk mask. (T_1, T_1)
-            left_context: Number of frames in left context.
-
-        Returns:
-            : Output tensor. (B, T_1, H * d_k)
-
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
-        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
diff --git a/funasr/models/encoder/chunk_encoder_modules/convolution.py b/funasr/models/encoder/chunk_encoder_modules/convolution.py
deleted file mode 100644
index 012538a..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/convolution.py
+++ /dev/null
@@ -1,196 +0,0 @@
-"""Convolution modules for X-former blocks."""
-
-from typing import Dict, Optional, Tuple
-
-import torch
-
-
-class ConformerConvolution(torch.nn.Module):
-    """ConformerConvolution module definition.
-
-    Args:
-        channels: The number of channels.
-        kernel_size: Size of the convolving kernel.
-        activation: Type of activation function.
-        norm_args: Normalization module arguments.
-        causal: Whether to use causal convolution (set to True if streaming).
-
-    """
-
-    def __init__(
-        self,
-        channels: int,
-        kernel_size: int,
-        activation: torch.nn.Module = torch.nn.ReLU(),
-        norm_args: Dict = {},
-        causal: bool = False,
-    ) -> None:
-        """Construct an ConformerConvolution object."""
-        super().__init__()
-
-        assert (kernel_size - 1) % 2 == 0
-
-        self.kernel_size = kernel_size
-
-        self.pointwise_conv1 = torch.nn.Conv1d(
-            channels,
-            2 * channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-        )
-
-        if causal:
-            self.lorder = kernel_size - 1
-            padding = 0
-        else:
-            self.lorder = 0
-            padding = (kernel_size - 1) // 2
-
-        self.depthwise_conv = torch.nn.Conv1d(
-            channels,
-            channels,
-            kernel_size,
-            stride=1,
-            padding=padding,
-            groups=channels,
-        )
-        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
-        self.pointwise_conv2 = torch.nn.Conv1d(
-            channels,
-            channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-        )
-
-        self.activation = activation
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        cache: Optional[torch.Tensor] = None,
-        right_context: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Compute convolution module.
-
-        Args:
-            x: ConformerConvolution input sequences. (B, T, D_hidden)
-            cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
-            right_context: Number of frames in right context.
-
-        Returns:
-            x: ConformerConvolution output sequences. (B, T, D_hidden)
-            cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
-
-        """
-        x = self.pointwise_conv1(x.transpose(1, 2))
-        x = torch.nn.functional.glu(x, dim=1)
-
-        if self.lorder > 0:
-            if cache is None:
-                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
-            else:
-                x = torch.cat([cache, x], dim=2)
-
-                if right_context > 0:
-                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
-                else:
-                    cache = x[:, :, -self.lorder :]
-
-        x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
-
-        x = self.pointwise_conv2(x).transpose(1, 2)
-
-        return x, cache
-
-
-class ConvolutionalSpatialGatingUnit(torch.nn.Module):
-    """Convolutional Spatial Gating Unit module definition.
-
-    Args:
-        size: Initial size to determine the number of channels.
-        kernel_size: Size of the convolving kernel.
-        norm_class: Normalization module class.
-        norm_args: Normalization module arguments.
-        dropout_rate: Dropout rate.
-        causal: Whether to use causal convolution (set to True if streaming).
-
-    """
-
-    def __init__(
-        self,
-        size: int,
-        kernel_size: int,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
-        norm_args: Dict = {},
-        dropout_rate: float = 0.0,
-        causal: bool = False,
-    ) -> None:
-        """Construct a ConvolutionalSpatialGatingUnit object."""
-        super().__init__()
-
-        channels = size // 2
-
-        self.kernel_size = kernel_size
-
-        if causal:
-            self.lorder = kernel_size - 1
-            padding = 0
-        else:
-            self.lorder = 0
-            padding = (kernel_size - 1) // 2
-
-        self.conv = torch.nn.Conv1d(
-            channels,
-            channels,
-            kernel_size,
-            stride=1,
-            padding=padding,
-            groups=channels,
-        )
-
-        self.norm = norm_class(channels, **norm_args)
-        self.activation = torch.nn.Identity()
-
-        self.dropout = torch.nn.Dropout(dropout_rate)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        cache: Optional[torch.Tensor] = None,
-        right_context: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Compute convolution module.
-
-        Args:
-            x: ConvolutionalSpatialGatingUnit input sequences. (B, T, D_hidden)
-            cache: ConvolutionalSpationGatingUnit input cache.
-                   (1, conv_kernel, D_hidden)
-            right_context: Number of frames in right context.
-
-        Returns:
-            x: ConvolutionalSpatialGatingUnit output sequences. (B, T, D_hidden // 2)
-
-        """
-        x_r, x_g = x.chunk(2, dim=-1)
-
-        x_g = self.norm(x_g).transpose(1, 2)
-
-        if self.lorder > 0:
-            if cache is None:
-                x_g = torch.nn.functional.pad(x_g, (self.lorder, 0), "constant", 0.0)
-            else:
-                x_g = torch.cat([cache, x_g], dim=2)
-
-                if right_context > 0:
-                    cache = x_g[:, :, -(self.lorder + right_context) : -right_context]
-                else:
-                    cache = x_g[:, :, -self.lorder :]
-
-        x_g = self.conv(x_g).transpose(1, 2)
-
-        x = self.dropout(x_r * self.activation(x_g))
-
-        return x, cache
diff --git a/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py b/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
deleted file mode 100644
index 14aca8b..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
+++ /dev/null
@@ -1,105 +0,0 @@
-"""MultiBlocks for encoder architecture."""
-
-from typing import Dict, List, Optional
-
-import torch
-
-
-class MultiBlocks(torch.nn.Module):
-    """MultiBlocks definition.
-
-    Args:
-        block_list: Individual blocks of the encoder architecture.
-        output_size: Architecture output size.
-        norm_class: Normalization module class.
-        norm_args: Normalization module arguments.
-
-    """
-
-    def __init__(
-        self,
-        block_list: List[torch.nn.Module],
-        output_size: int,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
-        norm_args: Optional[Dict] = None,
-    ) -> None:
-        """Construct a MultiBlocks object."""
-        super().__init__()
-
-        self.blocks = torch.nn.ModuleList(block_list)
-        self.norm_blocks = norm_class(output_size, **norm_args)
-
-        self.num_blocks = len(block_list)
-
-    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
-        """Initialize/Reset encoder streaming cache.
-
-        Args:
-            left_context: Number of left frames during chunk-by-chunk inference.
-            device: Device to use for cache tensor.
-
-        """
-        for idx in range(self.num_blocks):
-            self.blocks[idx].reset_streaming_cache(left_context, device)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Forward each block of the encoder architecture.
-
-        Args:
-            x: MultiBlocks input sequences. (B, T, D_block_1)
-            pos_enc: Positional embedding sequences.
-            mask: Source mask. (B, T)
-            chunk_mask: Chunk mask. (T_2, T_2)
-
-        Returns:
-            x: Output sequences. (B, T, D_block_N)
-
-        """
-        for block_index, block in enumerate(self.blocks):
-            x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
-
-        x = self.norm_blocks(x)
-
-        return x
-
-    def chunk_forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_size: int = 0,
-        left_context: int = 0,
-        right_context: int = 0,
-    ) -> torch.Tensor:
-        """Forward each block of the encoder architecture.
-
-        Args:
-            x: MultiBlocks input sequences. (B, T, D_block_1)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
-            mask: Source mask. (B, T_2)
-            left_context: Number of frames in left context.
-            right_context: Number of frames in right context.
-
-        Returns:
-            x: MultiBlocks output sequences. (B, T, D_block_N)
-
-        """
-        for block_idx, block in enumerate(self.blocks):
-            x, pos_enc = block.chunk_forward(
-                x,
-                pos_enc,
-                mask,
-                chunk_size=chunk_size,
-                left_context=left_context,
-                right_context=right_context,
-            )
-
-        x = self.norm_blocks(x)
-
-        return x
diff --git a/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py b/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
deleted file mode 100644
index 5b56e26..0000000
--- a/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
+++ /dev/null
@@ -1,91 +0,0 @@
-"""Positional encoding modules."""
-
-import math
-
-import torch
-
-from funasr.modules.embedding import _pre_hook
-
-
-class RelPositionalEncoding(torch.nn.Module):
-    """Relative positional encoding.
-
-    Args:
-        size: Module size.
-        max_len: Maximum input length.
-        dropout_rate: Dropout rate.
-
-    """
-
-    def __init__(
-        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
-    ) -> None:
-        """Construct a RelativePositionalEncoding object."""
-        super().__init__()
-
-        self.size = size
-
-        self.pe = None
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
-
-        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
-        self._register_load_state_dict_pre_hook(_pre_hook)
-
-    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
-        """Reset positional encoding.
-
-        Args:
-            x: Input sequences. (B, T, ?)
-            left_context: Number of frames in left context.
-
-        """
-        time1 = x.size(1) + left_context
-
-        if self.pe is not None:
-            if self.pe.size(1) >= time1 * 2 - 1:
-                if self.pe.dtype != x.dtype or self.pe.device != x.device:
-                    self.pe = self.pe.to(device=x.device, dtype=x.dtype)
-                return
-
-        pe_positive = torch.zeros(time1, self.size)
-        pe_negative = torch.zeros(time1, self.size)
-
-        position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, self.size, 2, dtype=torch.float32)
-            * -(math.log(10000.0) / self.size)
-        )
-
-        pe_positive[:, 0::2] = torch.sin(position * div_term)
-        pe_positive[:, 1::2] = torch.cos(position * div_term)
-        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
-
-        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
-        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
-        pe_negative = pe_negative[1:].unsqueeze(0)
-
-        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
-            dtype=x.dtype, device=x.device
-        )
-
-    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
-        """Compute positional encoding.
-
-        Args:
-            x: Input sequences. (B, T, ?)
-            left_context: Number of frames in left context.
-
-        Returns:
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?)
-
-        """
-        self.extend_pe(x, left_context=left_context)
-
-        time1 = x.size(1) + left_context
-
-        pos_enc = self.pe[
-            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
-        ]
-        pos_enc = self.dropout(pos_enc)
-
-        return pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_utils/building.py b/funasr/models/encoder/chunk_encoder_utils/building.py
deleted file mode 100644
index 21611aa..0000000
--- a/funasr/models/encoder/chunk_encoder_utils/building.py
+++ /dev/null
@@ -1,352 +0,0 @@
-"""Set of methods to build Transducer encoder architecture."""
-
-from typing import Any, Dict, List, Optional, Union
-
-from funasr.modules.activation import get_activation
-from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer
-from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer
-from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d
-from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput
-from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput
-from funasr.models.encoder.chunk_encoder_modules.attention import (  # noqa: H301
-    RelPositionMultiHeadedAttention,
-)
-from funasr.models.encoder.chunk_encoder_modules.convolution import (  # noqa: H301
-    ConformerConvolution,
-    ConvolutionalSpatialGatingUnit,
-)
-from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks
-from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization
-from funasr.models.encoder.chunk_encoder_modules.positional_encoding import (  # noqa: H301
-    RelPositionalEncoding,
-)
-from funasr.modules.positionwise_feed_forward import (
-    PositionwiseFeedForward,
-)
-
-
-def build_main_parameters(
-    pos_wise_act_type: str = "swish",
-    conv_mod_act_type: str = "swish",
-    pos_enc_dropout_rate: float = 0.0,
-    pos_enc_max_len: int = 5000,
-    simplified_att_score: bool = False,
-    norm_type: str = "layer_norm",
-    conv_mod_norm_type: str = "layer_norm",
-    after_norm_eps: Optional[float] = None,
-    after_norm_partial: Optional[float] = None,
-    dynamic_chunk_training: bool = False,
-    short_chunk_threshold: float = 0.75,
-    short_chunk_size: int = 25,
-    left_chunk_size: int = 0,
-    time_reduction_factor: int = 1,
-    unified_model_training: bool = False,
-    default_chunk_size: int = 16,
-    jitter_range: int =4,
-    **activation_parameters,
-) -> Dict[str, Any]:
-    """Build encoder main parameters.
-
-    Args:
-        pos_wise_act_type: Conformer position-wise feed-forward activation type.
-        conv_mod_act_type: Conformer convolution module activation type.
-        pos_enc_dropout_rate: Positional encoding dropout rate.
-        pos_enc_max_len: Positional encoding maximum length.
-        simplified_att_score: Whether to use simplified attention score computation.
-        norm_type: X-former normalization module type.
-        conv_mod_norm_type: Conformer convolution module normalization type.
-        after_norm_eps: Epsilon value for the final normalization.
-        after_norm_partial: Value for the final normalization with RMSNorm.
-        dynamic_chunk_training: Whether to use dynamic chunk training.
-        short_chunk_threshold: Threshold for dynamic chunk selection.
-        short_chunk_size: Minimum number of frames during dynamic chunk training.
-        left_chunk_size: Number of frames in left context.
-        **activations_parameters: Parameters of the activation functions.
-                                    (See espnet2/asr_transducer/activation.py)
-
-    Returns:
-        : Main encoder parameters
-
-    """
-    main_params = {}
-
-    main_params["pos_wise_act"] = get_activation(
-        pos_wise_act_type, **activation_parameters
-    )
-
-    main_params["conv_mod_act"] = get_activation(
-        conv_mod_act_type, **activation_parameters
-    )
-
-    main_params["pos_enc_dropout_rate"] = pos_enc_dropout_rate
-    main_params["pos_enc_max_len"] = pos_enc_max_len
-
-    main_params["simplified_att_score"] = simplified_att_score
-
-    main_params["norm_type"] = norm_type
-    main_params["conv_mod_norm_type"] = conv_mod_norm_type
-
-    (
-        main_params["after_norm_class"],
-        main_params["after_norm_args"],
-    ) = get_normalization(norm_type, eps=after_norm_eps, partial=after_norm_partial)
-
-    main_params["dynamic_chunk_training"] = dynamic_chunk_training
-    main_params["short_chunk_threshold"] = max(0, short_chunk_threshold)
-    main_params["short_chunk_size"] = max(0, short_chunk_size)
-    main_params["left_chunk_size"] = max(0, left_chunk_size)
-    
-    main_params["unified_model_training"] = unified_model_training
-    main_params["default_chunk_size"] = max(0, default_chunk_size)
-    main_params["jitter_range"] = max(0, jitter_range)
-   
-    main_params["time_reduction_factor"] = time_reduction_factor
-
-    return main_params
-
-
-def build_positional_encoding(
-    block_size: int, configuration: Dict[str, Any]
-) -> RelPositionalEncoding:
-    """Build positional encoding block.
-
-    Args:
-        block_size: Input/output size.
-        configuration: Positional encoding configuration.
-
-    Returns:
-        : Positional encoding module.
-
-    """
-    return RelPositionalEncoding(
-        block_size,
-        configuration.get("pos_enc_dropout_rate", 0.0),
-        max_len=configuration.get("pos_enc_max_len", 5000),
-    )
-
-
-def build_input_block(
-    input_size: int,
-    configuration: Dict[str, Union[str, int]],
-) -> ConvInput:
-    """Build encoder input block.
-
-    Args:
-        input_size: Input size.
-        configuration: Input block configuration.
-
-    Returns:
-        : ConvInput block function.
-
-    """
-    if configuration["linear"]:
-        return LinearInput(
-            input_size,
-            configuration["output_size"],
-            configuration["subsampling_factor"],
-        )
-    else:
-        return ConvInput(
-            input_size,
-            configuration["conv_size"],
-            configuration["subsampling_factor"],
-            vgg_like=configuration["vgg_like"],
-            output_size=configuration["output_size"],
-        )
-
-
-def build_branchformer_block(
-    configuration: List[Dict[str, Any]],
-    main_params: Dict[str, Any],
-) -> Conformer:
-    """Build Branchformer block.
-
-    Args:
-        configuration: Branchformer block configuration.
-        main_params: Encoder main parameters.
-
-    Returns:
-        : Branchformer block function.
-
-    """
-    hidden_size = configuration["hidden_size"]
-    linear_size = configuration["linear_size"]
-
-    dropout_rate = configuration.get("dropout_rate", 0.0)
-
-    conv_mod_norm_class, conv_mod_norm_args = get_normalization(
-        main_params["conv_mod_norm_type"],
-        eps=configuration.get("conv_mod_norm_eps"),
-        partial=configuration.get("conv_mod_norm_partial"),
-    )
-
-    conv_mod_args = (
-        linear_size,
-        configuration["conv_mod_kernel_size"],
-        conv_mod_norm_class,
-        conv_mod_norm_args,
-        dropout_rate,
-        main_params["dynamic_chunk_training"],
-    )
-
-    mult_att_args = (
-        configuration.get("heads", 4),
-        hidden_size,
-        configuration.get("att_dropout_rate", 0.0),
-        main_params["simplified_att_score"],
-    )
-
-    norm_class, norm_args = get_normalization(
-        main_params["norm_type"],
-        eps=configuration.get("norm_eps"),
-        partial=configuration.get("norm_partial"),
-    )
-
-    return lambda: Branchformer(
-        hidden_size,
-        linear_size,
-        RelPositionMultiHeadedAttention(*mult_att_args),
-        ConvolutionalSpatialGatingUnit(*conv_mod_args),
-        norm_class=norm_class,
-        norm_args=norm_args,
-        dropout_rate=dropout_rate,
-    )
-
-
-def build_conformer_block(
-    configuration: List[Dict[str, Any]],
-    main_params: Dict[str, Any],
-) -> Conformer:
-    """Build Conformer block.
-
-    Args:
-        configuration: Conformer block configuration.
-        main_params: Encoder main parameters.
-
-    Returns:
-        : Conformer block function.
-
-    """
-    hidden_size = configuration["hidden_size"]
-    linear_size = configuration["linear_size"]
-
-    pos_wise_args = (
-        hidden_size,
-        linear_size,
-        configuration.get("pos_wise_dropout_rate", 0.0),
-        main_params["pos_wise_act"],
-    )
-
-    conv_mod_norm_args = {
-        "eps": configuration.get("conv_mod_norm_eps", 1e-05),
-        "momentum": configuration.get("conv_mod_norm_momentum", 0.1),
-    }
-
-    conv_mod_args = (
-        hidden_size,
-        configuration["conv_mod_kernel_size"],
-        main_params["conv_mod_act"],
-        conv_mod_norm_args,
-        main_params["dynamic_chunk_training"] or main_params["unified_model_training"],
-    )
-
-    mult_att_args = (
-        configuration.get("heads", 4),
-        hidden_size,
-        configuration.get("att_dropout_rate", 0.0),
-        main_params["simplified_att_score"],
-    )
-
-    norm_class, norm_args = get_normalization(
-        main_params["norm_type"],
-        eps=configuration.get("norm_eps"),
-        partial=configuration.get("norm_partial"),
-    )
-
-    return lambda: Conformer(
-        hidden_size,
-        RelPositionMultiHeadedAttention(*mult_att_args),
-        PositionwiseFeedForward(*pos_wise_args),
-        PositionwiseFeedForward(*pos_wise_args),
-        ConformerConvolution(*conv_mod_args),
-        norm_class=norm_class,
-        norm_args=norm_args,
-        dropout_rate=configuration.get("dropout_rate", 0.0),
-    )
-
-
-def build_conv1d_block(
-    configuration: List[Dict[str, Any]],
-    causal: bool,
-) -> Conv1d:
-    """Build Conv1d block.
-
-    Args:
-        configuration: Conv1d block configuration.
-
-    Returns:
-        : Conv1d block function.
-
-    """
-    return lambda: Conv1d(
-        configuration["input_size"],
-        configuration["output_size"],
-        configuration["kernel_size"],
-        stride=configuration.get("stride", 1),
-        dilation=configuration.get("dilation", 1),
-        groups=configuration.get("groups", 1),
-        bias=configuration.get("bias", True),
-        relu=configuration.get("relu", True),
-        batch_norm=configuration.get("batch_norm", False),
-        causal=causal,
-        dropout_rate=configuration.get("dropout_rate", 0.0),
-    )
-
-
-def build_body_blocks(
-    configuration: List[Dict[str, Any]],
-    main_params: Dict[str, Any],
-    output_size: int,
-) -> MultiBlocks:
-    """Build encoder body blocks.
-
-    Args:
-        configuration: Body blocks configuration.
-        main_params: Encoder main parameters.
-        output_size: Architecture output size.
-
-    Returns:
-        MultiBlocks function encapsulation all encoder blocks.
-
-    """
-    fn_modules = []
-    extended_conf = []
-
-    for c in configuration:
-        if c.get("num_blocks") is not None:
-            extended_conf += c["num_blocks"] * [
-                {c_i: c[c_i] for c_i in c if c_i != "num_blocks"}
-            ]
-        else:
-            extended_conf += [c]
-
-    for i, c in enumerate(extended_conf):
-        block_type = c["block_type"]
-
-        if block_type == "branchformer":
-            module = build_branchformer_block(c, main_params)
-        elif block_type == "conformer":
-            module = build_conformer_block(c, main_params)
-        elif block_type == "conv1d":
-            module = build_conv1d_block(c, main_params["dynamic_chunk_training"])
-        else:
-            raise NotImplementedError
-
-        fn_modules.append(module)
-
-    return MultiBlocks(
-        [fn() for fn in fn_modules],
-        output_size,
-        norm_class=main_params["after_norm_class"],
-        norm_args=main_params["after_norm_args"],
-    )
diff --git a/funasr/models/encoder/chunk_encoder_utils/validation.py b/funasr/models/encoder/chunk_encoder_utils/validation.py
deleted file mode 100644
index 1103cb9..0000000
--- a/funasr/models/encoder/chunk_encoder_utils/validation.py
+++ /dev/null
@@ -1,171 +0,0 @@
-"""Set of methods to validate encoder architecture."""
-
-from typing import Any, Dict, List, Tuple
-
-from funasr.modules.nets_utils import sub_factor_to_params
-
-
-def validate_block_arguments(
-    configuration: Dict[str, Any],
-    block_id: int,
-    previous_block_output: int,
-) -> Tuple[int, int]:
-    """Validate block arguments.
-
-    Args:
-        configuration: Architecture configuration.
-        block_id: Block ID.
-        previous_block_output: Previous block output size.
-
-    Returns:
-        input_size: Block input size.
-        output_size: Block output size.
-
-    """
-    block_type = configuration.get("block_type")
-
-    if block_type is None:
-        raise ValueError(
-            "Block %d in encoder doesn't have a type assigned. " % block_id
-        )
-
-    if block_type in ["branchformer", "conformer"]:
-        if configuration.get("linear_size") is None:
-            raise ValueError(
-                "Missing 'linear_size' argument for X-former block (ID: %d)" % block_id
-            )
-
-        if configuration.get("conv_mod_kernel_size") is None:
-            raise ValueError(
-                "Missing 'conv_mod_kernel_size' argument for X-former block (ID: %d)"
-                % block_id
-            )
-
-        input_size = configuration.get("hidden_size")
-        output_size = configuration.get("hidden_size")
-
-    elif block_type == "conv1d":
-        output_size = configuration.get("output_size")
-
-        if output_size is None:
-            raise ValueError(
-                "Missing 'output_size' argument for Conv1d block (ID: %d)" % block_id
-            )
-
-        if configuration.get("kernel_size") is None:
-            raise ValueError(
-                "Missing 'kernel_size' argument for Conv1d block (ID: %d)" % block_id
-            )
-
-        input_size = configuration["input_size"] = previous_block_output
-    else:
-        raise ValueError("Block type: %s is not supported." % block_type)
-
-    return input_size, output_size
-
-
-def validate_input_block(
-    configuration: Dict[str, Any], body_first_conf: Dict[str, Any], input_size: int
-) -> int:
-    """Validate input block.
-
-    Args:
-        configuration: Encoder input block configuration.
-        body_first_conf: Encoder first body block configuration.
-        input_size: Encoder input block input size.
-
-    Return:
-        output_size: Encoder input block output size.
-
-    """
-    vgg_like = configuration.get("vgg_like", False)
-    linear = configuration.get("linear", False)
-    next_block_type = body_first_conf.get("block_type")
-    allowed_next_block_type = ["branchformer", "conformer", "conv1d"]
-
-    if next_block_type is None or (next_block_type not in allowed_next_block_type):
-        return -1
-
-    if configuration.get("subsampling_factor") is None:
-        configuration["subsampling_factor"] = 4
-
-    if vgg_like:
-        conv_size = configuration.get("conv_size", (64, 128))
-
-        if isinstance(conv_size, int):
-            conv_size = (conv_size, conv_size)
-    else:
-        conv_size = configuration.get("conv_size", None)
-
-        if isinstance(conv_size, tuple):
-            conv_size = conv_size[0]
-
-    if next_block_type == "conv1d":
-        if vgg_like:
-            output_size = conv_size[1] * ((input_size // 2) // 2)
-        else:
-            if conv_size is None:
-                conv_size = body_first_conf.get("output_size", 64)
-
-            sub_factor = configuration["subsampling_factor"]
-
-            _, _, conv_osize = sub_factor_to_params(sub_factor, input_size)
-            assert (
-                conv_osize > 0
-            ), "Conv2D output size is <1 with input size %d and subsampling %d" % (
-                input_size,
-                sub_factor,
-            )
-
-            output_size = conv_osize * conv_size
-
-        configuration["output_size"] = None
-    else:
-        output_size = body_first_conf.get("hidden_size")
-
-        if conv_size is None:
-            conv_size = output_size
-
-        configuration["output_size"] = output_size
-
-    configuration["conv_size"] = conv_size
-    configuration["vgg_like"] = vgg_like
-    configuration["linear"] = linear
-
-    return output_size
-
-
-def validate_architecture(
-    input_conf: Dict[str, Any], body_conf: List[Dict[str, Any]], input_size: int
-) -> Tuple[int, int]:
-    """Validate specified architecture is valid.
-
-    Args:
-        input_conf: Encoder input block configuration.
-        body_conf: Encoder body blocks configuration.
-        input_size: Encoder input size.
-
-    Returns:
-        input_block_osize: Encoder input block output size.
-        : Encoder body block output size.
-
-    """
-    input_block_osize = validate_input_block(input_conf, body_conf[0], input_size)
-
-    cmp_io = []
-
-    for i, b in enumerate(body_conf):
-        _io = validate_block_arguments(
-            b, (i + 1), input_block_osize if i == 0 else cmp_io[i - 1][1]
-        )
-
-        cmp_io.append(_io)
-
-    for i in range(1, len(cmp_io)):
-        if cmp_io[(i - 1)][1] != cmp_io[i][0]:
-            raise ValueError(
-                "Output/Input mismatch between blocks %d and %d"
-                " in the encoder body." % ((i - 1), i)
-            )
-
-    return input_block_osize, cmp_io[-1][1]
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 7c7f661..c837cf5 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -8,6 +8,7 @@
 from typing import Optional
 from typing import Tuple
 from typing import Union
+from typing import Dict
 
 import torch
 from torch import nn
@@ -18,6 +19,7 @@
 from funasr.modules.attention import (
     MultiHeadedAttention,  # noqa: H301
     RelPositionMultiHeadedAttention,  # noqa: H301
+    RelPositionMultiHeadedAttentionChunk,
     LegacyRelPositionMultiHeadedAttention,  # noqa: H301
 )
 from funasr.modules.embedding import (
@@ -25,16 +27,24 @@
     ScaledPositionalEncoding,  # noqa: H301
     RelPositionalEncoding,  # noqa: H301
     LegacyRelPositionalEncoding,  # noqa: H301
+    StreamingRelPositionalEncoding,
 )
 from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.normalization import get_normalization
 from funasr.modules.multi_layer_conv import Conv1dLinear
 from funasr.modules.multi_layer_conv import MultiLayeredConv1d
 from funasr.modules.nets_utils import get_activation
 from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.nets_utils import (
+    TooShortUttError,
+    check_short_utt,
+    make_chunk_mask,
+    make_source_mask,
+)
 from funasr.modules.positionwise_feed_forward import (
     PositionwiseFeedForward,  # noqa: H301
 )
-from funasr.modules.repeat import repeat
+from funasr.modules.repeat import repeat, MultiBlocks
 from funasr.modules.subsampling import Conv2dSubsampling
 from funasr.modules.subsampling import Conv2dSubsampling2
 from funasr.modules.subsampling import Conv2dSubsampling6
@@ -42,6 +52,8 @@
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
 from funasr.modules.subsampling import Conv2dSubsamplingPad
+from funasr.modules.subsampling import StreamingConvInput
+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
 
@@ -275,6 +287,188 @@
             return (x, pos_emb), mask
 
         return x, mask
+
+class ChunkEncoderLayer(torch.nn.Module):
+    """Chunk Conformer module definition.
+    Args:
+        block_size: Input/output size.
+        self_att: Self-attention module instance.
+        feed_forward: Feed-forward module instance.
+        feed_forward_macaron: Feed-forward module instance for macaron network.
+        conv_mod: Convolution module instance.
+        norm_class: Normalization module class.
+        norm_args: Normalization module arguments.
+        dropout_rate: Dropout rate.
+    """
+
+    def __init__(
+        self,
+        block_size: int,
+        self_att: torch.nn.Module,
+        feed_forward: torch.nn.Module,
+        feed_forward_macaron: torch.nn.Module,
+        conv_mod: torch.nn.Module,
+        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_args: Dict = {},
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Construct a ChunkEncoderLayer object."""
+        super().__init__()
+
+        self.self_att = self_att
+
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.feed_forward_scale = 0.5
+
+        self.conv_mod = conv_mod
+
+        self.norm_feed_forward = norm_class(block_size, **norm_args)
+        self.norm_self_att = norm_class(block_size, **norm_args)
+
+        self.norm_macaron = norm_class(block_size, **norm_args)
+        self.norm_conv = norm_class(block_size, **norm_args)
+        self.norm_final = norm_class(block_size, **norm_args)
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        self.block_size = block_size
+        self.cache = None
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset self-attention and convolution modules cache for streaming.
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+        """
+        self.cache = [
+            torch.zeros(
+                (1, left_context, self.block_size),
+                device=device,
+            ),
+            torch.zeros(
+                (
+                    1,
+                    self.block_size,
+                    self.conv_mod.kernel_size - 1,
+                ),
+                device=device,
+            ),
+        ]
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            mask: Source mask. (B, T)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.dropout(
+            self.feed_forward_macaron(x)
+        )
+
+        residual = x
+        x = self.norm_self_att(x)
+        x_q = x
+        x = residual + self.dropout(
+            self.self_att(
+                x_q,
+                x,
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+        )
+
+        residual = x
+
+        x = self.norm_conv(x)
+        x, _ = self.conv_mod(x)
+        x = residual + self.dropout(x)
+        residual = x
+
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
+
+        x = self.norm_final(x)
+        return x, mask, pos_enc
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode chunk of input sequence.
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T_2)
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
+
+        residual = x
+        x = self.norm_self_att(x)
+        if left_context > 0:
+            key = torch.cat([self.cache[0], x], dim=1)
+        else:
+            key = x
+        val = key
+
+        if right_context > 0:
+            att_cache = key[:, -(left_context + right_context) : -right_context, :]
+        else:
+            att_cache = key[:, -left_context:, :]
+        x = residual + self.self_att(
+            x,
+            key,
+            val,
+            pos_enc,
+            mask,
+            left_context=left_context,
+        )
+
+        residual = x
+        x = self.norm_conv(x)
+        x, conv_cache = self.conv_mod(
+            x, cache=self.cache[1], right_context=right_context
+        )
+        x = residual + x
+        residual = x
+
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.feed_forward(x)
+
+        x = self.norm_final(x)
+        self.cache = [att_cache, conv_cache]
+
+        return x, pos_enc
 
 
 class ConformerEncoder(AbsEncoder):
@@ -604,3 +798,447 @@
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
+
+
+class CausalConvolution(torch.nn.Module):
+    """ConformerConvolution module definition.
+    Args:
+        channels: The number of channels.
+        kernel_size: Size of the convolving kernel.
+        activation: Type of activation function.
+        norm_args: Normalization module arguments.
+        causal: Whether to use causal convolution (set to True if streaming).
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        norm_args: Dict = {},
+        causal: bool = False,
+    ) -> None:
+        """Construct a CausalConvolution object."""
+        super().__init__()
+
+        assert (kernel_size - 1) % 2 == 0
+
+        self.kernel_size = kernel_size
+
+        self.pointwise_conv1 = torch.nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+        if causal:
+            self.lorder = kernel_size - 1
+            padding = 0
+        else:
+            self.lorder = 0
+            padding = (kernel_size - 1) // 2
+
+        self.depthwise_conv = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+        )
+        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
+        self.pointwise_conv2 = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+        self.activation = activation
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: Optional[torch.Tensor] = None,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute convolution module.
+        Args:
+            x: ConformerConvolution input sequences. (B, T, D_hidden)
+            cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
+            right_context: Number of frames in right context.
+        Returns:
+            x: ConformerConvolution output sequences. (B, T, D_hidden)
+            cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
+        """
+        x = self.pointwise_conv1(x.transpose(1, 2))
+        x = torch.nn.functional.glu(x, dim=1)
+
+        if self.lorder > 0:
+            if cache is None:
+                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            else:
+                x = torch.cat([cache, x], dim=2)
+
+                if right_context > 0:
+                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
+                else:
+                    cache = x[:, :, -self.lorder :]
+
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x).transpose(1, 2)
+
+        return x, cache
+
+class ConformerChunkEncoder(torch.nn.Module):
+    """Encoder module definition.
+    Args:
+        input_size: Input size.
+        output_size: Output/attention hidden dimension.
+        num_blocks: Number of encoder blocks.
+        (See ``__init__`` for the remaining configuration arguments.)
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        embed_vgg_like: bool = False,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 3,
+        macaron_style: bool = False,
+        rel_pos_type: str = "legacy",
+        pos_enc_layer_type: str = "rel_pos",
+        selfattention_layer_type: str = "rel_selfattn",
+        activation_type: str = "swish",
+        use_cnn_module: bool = True,
+        zero_triu: bool = False,
+        norm_type: str = "layer_norm",
+        cnn_module_kernel: int = 31,
+        conv_mod_norm_eps: float = 0.00001,
+        conv_mod_norm_momentum: float = 0.1,
+        simplified_att_score: bool = False,
+        dynamic_chunk_training: bool = False,
+        short_chunk_threshold: float = 0.75,
+        short_chunk_size: int = 25,
+        left_chunk_size: int = 0,
+        time_reduction_factor: int = 1,
+        unified_model_training: bool = False,
+        default_chunk_size: int = 16,
+        jitter_range: int = 4,
+        subsampling_factor: int = 1,
+        **activation_parameters,
+    ) -> None:
+        """Construct an Encoder object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        self.embed = StreamingConvInput(
+            input_size,
+            output_size,
+            subsampling_factor,
+            vgg_like=embed_vgg_like,
+            output_size=output_size,
+        )
+
+        self.pos_enc = StreamingRelPositionalEncoding(
+            output_size,
+            positional_dropout_rate,
+        )
+
+        activation = get_activation(
+            activation_type, **activation_parameters
+        )
+
+        pos_wise_args = (
+            output_size,
+            linear_units,
+            positional_dropout_rate,
+            activation,
+        )
+
+        conv_mod_norm_args = {
+            "eps": conv_mod_norm_eps,
+            "momentum": conv_mod_norm_momentum,
+        }
+
+        conv_mod_args = (
+            output_size,
+            cnn_module_kernel,
+            activation,
+            conv_mod_norm_args,
+            dynamic_chunk_training or unified_model_training,
+        )
+
+        mult_att_args = (
+            attention_heads,
+            output_size,
+            attention_dropout_rate,
+            simplified_att_score,
+        )
+
+        norm_class, norm_args = get_normalization(
+            norm_type,
+        )
+
+        fn_modules = []
+        for _ in range(num_blocks):
+            module = lambda: ChunkEncoderLayer(
+                output_size,
+                RelPositionMultiHeadedAttentionChunk(*mult_att_args),
+                PositionwiseFeedForward(*pos_wise_args),
+                PositionwiseFeedForward(*pos_wise_args),
+                CausalConvolution(*conv_mod_args),
+                norm_class=norm_class,
+                norm_args=norm_args,
+                dropout_rate=dropout_rate,
+            )
+            fn_modules.append(module)
+
+        self.encoders = MultiBlocks(
+            [fn() for fn in fn_modules],
+            output_size,
+            norm_class=norm_class,
+            norm_args=norm_args,
+        )
+
+        self.output_size = output_size
+
+        self.dynamic_chunk_training = dynamic_chunk_training
+        self.short_chunk_threshold = short_chunk_threshold
+        self.short_chunk_size = short_chunk_size
+        self.left_chunk_size = left_chunk_size
+
+        self.unified_model_training = unified_model_training
+        self.default_chunk_size = default_chunk_size
+        self.jitter_range = jitter_range
+
+        self.time_reduction_factor = time_reduction_factor
+
+    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
+        """Return the corresponding number of raw samples for a given chunk size.
+        Where size is the number of feature frames after applying subsampling.
+        Args:
+            size: Number of frames after subsampling.
+            hop_length: Frontend's hop length
+        Returns:
+            : Number of raw samples
+        """
+        return self.embed.get_size_before_subsampling(size) * hop_length
+
+    def get_encoder_input_size(self, size: int) -> int:
+        """Return the corresponding number of frames for a given chunk size.
+        Where size is the number of feature frames after applying subsampling.
+        Args:
+            size: Number of frames after subsampling.
+        Returns:
+            : Number of frames before subsampling.
+        """
+        return self.embed.get_size_before_subsampling(size)
+
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset encoder streaming cache.
+        Args:
+            left_context: Number of frames in left context.
+            device: Device ID.
+        """
+        return self.encoders.reset_streaming_cache(left_context, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+           x_len: Encoder outputs lengths. (B,)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len)
+
+        if self.unified_model_training:
+            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
+            x_utt = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=None,
+            )
+            x_chunk = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+
+            olens = mask.eq(0).sum(1)
+            if self.time_reduction_factor > 1:
+                x_utt = x_utt[:,::self.time_reduction_factor,:]
+                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
+                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+            return x_utt, x_chunk, olens
+
+        elif self.dynamic_chunk_training:
+            max_len = x.size(1)
+            chunk_size = torch.randint(1, max_len, (1,)).item()
+
+            if chunk_size > (max_len * self.short_chunk_threshold):
+                chunk_size = max_len
+            else:
+                chunk_size = (chunk_size % self.short_chunk_size) + 1
+
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
+
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
+        else:
+            x, mask = self.embed(x, mask, None)
+            pos_enc = self.pos_enc(x)
+            chunk_mask = None
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+
+        olens = mask.eq(0).sum(1)
+        if self.time_reduction_factor > 1:
+            x = x[:,::self.time_reduction_factor,:]
+            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+        return x, olens
+
+    def simu_chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len)
+
+        x, mask = self.embed(x, mask, chunk_size)
+        pos_enc = self.pos_enc(x)
+        chunk_mask = make_chunk_mask(
+            x.size(1),
+            chunk_size,
+            left_chunk_size=self.left_chunk_size,
+            device=x.device,
+        )
+
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+        olens = mask.eq(0).sum(1)
+        if self.time_reduction_factor > 1:
+            x = x[:,::self.time_reduction_factor,:]
+
+        return x
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        processed_frames: torch.tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Encode input sequences as chunks.
+        Args:
+            x: Encoder input features. (1, T_in, F)
+            x_len: Encoder input features lengths. (1,)
+            processed_frames: Number of frames already seen.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+        """
+        mask = make_source_mask(x_len)
+        x, mask = self.embed(x, mask, None)
+
+        if left_context > 0:
+            processed_mask = (
+                torch.arange(left_context, device=x.device)
+                .view(1, left_context)
+                .flip(1)
+            )
+            processed_mask = processed_mask >= processed_frames
+            mask = torch.cat([processed_mask, mask], dim=1)
+        pos_enc = self.pos_enc(x, left_context=left_context)
+        x = self.encoders.chunk_forward(
+            x,
+            pos_enc,
+            mask,
+            chunk_size=chunk_size,
+            left_context=left_context,
+            right_context=right_context,
+        )
+
+        if right_context > 0:
+            x = x[:, 0:-right_context, :]
+
+        if self.time_reduction_factor > 1:
+            x = x[:,::self.time_reduction_factor,:]
+        return x
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 31d5a87..6202079 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -11,7 +11,7 @@
 import numpy
 import torch
 from torch import nn
-
+from typing import Optional, Tuple
 
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
@@ -741,3 +741,221 @@
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
         return att_outs
+
+class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
+    """RelPositionMultiHeadedAttention definition.
+    Args:
+        num_heads: Number of attention heads.
+        embed_size: Embedding size.
+        dropout_rate: Dropout rate.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        embed_size: int,
+        dropout_rate: float = 0.0,
+        simplified_attention_score: bool = False,
+    ) -> None:
+        """Construct a RelPositionMultiHeadedAttentionChunk object."""
+        super().__init__()
+
+        self.d_k = embed_size // num_heads
+        self.num_heads = num_heads
+
+        assert self.d_k * num_heads == embed_size, (
+            "embed_size (%d) must be divisible by num_heads (%d)"
+            % (embed_size, num_heads)
+        )
+
+        self.linear_q = torch.nn.Linear(embed_size, embed_size)
+        self.linear_k = torch.nn.Linear(embed_size, embed_size)
+        self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+        self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+        if simplified_attention_score:
+            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+            self.compute_att_score = self.compute_simplified_attention_score
+        else:
+            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            torch.nn.init.xavier_uniform_(self.pos_bias_u)
+            torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+            self.compute_att_score = self.compute_attention_score
+
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.attn = None
+
+    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+        """Compute relative positional encoding.
+        Args:
+            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+            left_context: Number of frames in left context.
+        Returns:
+            x: Output sequence. (B, H, T_1, T_2)
+        """
+        batch_size, n_heads, time1, n = x.shape
+        time2 = time1 + left_context
+
+        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
+
+        return x.as_strided(
+            (batch_size, n_heads, time1, time2),
+            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+            storage_offset=(n_stride * (time1 - 1)),
+        )
+
+    def compute_simplified_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Simplified attention score computation.
+        Reference: https://github.com/k2-fsa/icefall/pull/458
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        pos_enc = self.linear_pos(pos_enc)
+
+        matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+        matrix_bd = self.rel_shift(
+            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+            left_context=left_context,
+        )
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def compute_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Attention score computation.
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+        query = query.transpose(1, 2)
+        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def forward_qkv(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Transform query, key and value.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+        Returns:
+            q: Transformed query tensor. (B, H, T_1, d_k)
+            k: Transformed key tensor. (B, H, T_2, d_k)
+            v: Transformed value tensor. (B, H, T_2, d_k)
+        """
+        n_batch = query.size(0)
+
+        q = (
+            self.linear_q(query)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        k = (
+            self.linear_k(key)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        v = (
+            self.linear_v(value)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+
+        return q, k, v
+
+    def forward_attention(
+        self,
+        value: torch.Tensor,
+        scores: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute attention context vector.
+        Args:
+            value: Transformed value. (B, H, T_2, d_k)
+            scores: Attention score. (B, H, T_1, T_2)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+        Returns:
+           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+        """
+        batch_size = scores.size(0)
+        mask = mask.unsqueeze(1).unsqueeze(2)
+        if chunk_mask is not None:
+            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+        scores = scores.masked_fill(mask, float("-inf"))
+        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+        attn_output = self.dropout(self.attn)
+        attn_output = torch.matmul(attn_output, value)
+
+        attn_output = self.linear_out(
+            attn_output.transpose(1, 2)
+            .contiguous()
+            .view(batch_size, -1, self.num_heads * self.d_k)
+        )
+
+        return attn_output
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Compute scaled dot product attention with rel. positional encoding.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+            left_context: Number of frames in left context.
+        Returns:
+            : Output tensor. (B, T_1, H * d_k)
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index 79ca0b2..e0070de 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -423,4 +423,79 @@
         outputs = F.pad(outputs, (pad_left, pad_right))
         outputs = outputs.transpose(1,2)
         return outputs
-       
+
+class StreamingRelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding.
+    Args:
+        size: Module size.
+        max_len: Maximum input length.
+        dropout_rate: Dropout rate.
+    """
+
+    def __init__(
+        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
+    ) -> None:
+        """Construct a StreamingRelPositionalEncoding object."""
+        super().__init__()
+
+        self.size = size
+
+        self.pe = None
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+        self._register_load_state_dict_pre_hook(_pre_hook)
+
+    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
+        """Reset positional encoding.
+        Args:
+            x: Input sequences. (B, T, ?)
+            left_context: Number of frames in left context.
+        """
+        time1 = x.size(1) + left_context
+
+        if self.pe is not None:
+            if self.pe.size(1) >= time1 * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(device=x.device, dtype=x.dtype)
+                return
+
+        pe_positive = torch.zeros(time1, self.size)
+        pe_negative = torch.zeros(time1, self.size)
+
+        position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.size, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.size)
+        )
+
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+
+        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
+            dtype=x.dtype, device=x.device
+        )
+
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+        """Compute positional encoding.
+        Args:
+            x: Input sequences. (B, T, ?)
+            left_context: Number of frames in left context.
+        Returns:
+            pos_enc: Positional embedding sequences. (B, 2 * T - 1, ?)
+        """
+        self.extend_pe(x, left_context=left_context)
+
+        time1 = x.size(1) + left_context
+
+        pos_enc = self.pe[
+            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
+        ]
+        pos_enc = self.dropout(pos_enc)
+
+        return pos_enc
diff --git a/funasr/models/encoder/chunk_encoder_modules/normalization.py b/funasr/modules/normalization.py
similarity index 100%
rename from funasr/models/encoder/chunk_encoder_modules/normalization.py
rename to funasr/modules/normalization.py
diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py
index a3d2676..7241dd9 100644
--- a/funasr/modules/repeat.py
+++ b/funasr/modules/repeat.py
@@ -6,6 +6,8 @@
 
 """Repeat the same layer definition."""
 
+from typing import Dict, List, Optional
+
 import torch
 
 
@@ -31,3 +33,93 @@
 
     """
     return MultiSequential(*[fn(n) for n in range(N)])
+
+
+class MultiBlocks(torch.nn.Module):
+    """MultiBlocks definition.
+    Args:
+        block_list: Individual blocks of the encoder architecture.
+        output_size: Architecture output size.
+        norm_class: Normalization module class.
+        norm_args: Normalization module arguments.
+    """
+
+    def __init__(
+        self,
+        block_list: List[torch.nn.Module],
+        output_size: int,
+        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_args: Optional[Dict] = None,
+    ) -> None:
+        """Construct a MultiBlocks object."""
+        super().__init__()
+
+        self.blocks = torch.nn.ModuleList(block_list)
+        self.norm_blocks = norm_class(output_size, **norm_args)
+
+        self.num_blocks = len(block_list)
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset encoder streaming cache.
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+        """
+        for idx in range(self.num_blocks):
+            self.blocks[idx].reset_streaming_cache(left_context, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward each block of the encoder architecture.
+        Args:
+            x: MultiBlocks input sequences. (B, T, D_block_1)
+            pos_enc: Positional embedding sequences.
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+        Returns:
+            x: Output sequences. (B, T, D_block_N)
+        """
+        for block_index, block in enumerate(self.blocks):
+            x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
+
+        x = self.norm_blocks(x)
+
+        return x
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int = 0,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Forward each block of the encoder architecture.
+        Args:
+            x: MultiBlocks input sequences. (B, T, D_block_1)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
+            mask: Source mask. (B, T_2)
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+            x: MultiBlocks output sequences. (B, T, D_block_N)
+        """
+        for block_idx, block in enumerate(self.blocks):
+            x, pos_enc = block.chunk_forward(
+                x,
+                pos_enc,
+                mask,
+                chunk_size=chunk_size,
+                left_context=left_context,
+                right_context=right_context,
+            )
+
+        x = self.norm_blocks(x)
+
+        return x
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index d492ccf..623be65 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -11,6 +11,10 @@
 from funasr.modules.embedding import PositionalEncoding
 import logging
 from funasr.modules.streaming_utils.utils import sequence_mask
+from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
+from typing import Optional, Tuple, Union
+import math
+
 class TooShortUttError(Exception):
     """Raised when the utt is too short for subsampling.
 
@@ -407,3 +411,201 @@
                                                                                   var_dict_tf[name_tf].shape))
         return var_dict_torch_update
 
+class StreamingConvInput(torch.nn.Module):
+    """Streaming ConvInput module definition.
+    Args:
+        input_size: Input size.
+        conv_size: Convolution size.
+        subsampling_factor: Subsampling factor.
+        vgg_like: Whether to use a VGG-like network.
+        output_size: Block output dimension.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        conv_size: Union[int, Tuple],
+        subsampling_factor: int = 4,
+        vgg_like: bool = True,
+        output_size: Optional[int] = None,
+    ) -> None:
+        """Construct a StreamingConvInput object."""
+        super().__init__()
+        if vgg_like:
+            if subsampling_factor == 1:
+                conv_size1, conv_size2 = conv_size
+
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((1, 2)),
+                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((1, 2)),
+                )
+
+                output_proj = conv_size2 * ((input_size // 2) // 2)
+
+                self.subsampling_factor = 1
+
+                self.stride_1 = 1
+
+                self.create_new_mask = self.create_new_vgg_mask
+
+            else:
+                conv_size1, conv_size2 = conv_size
+
+                kernel_1 = int(subsampling_factor / 2)
+
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((kernel_1, 2)),
+                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((2, 2)),
+                )
+
+                output_proj = conv_size2 * ((input_size // 2) // 2)
+
+                self.subsampling_factor = subsampling_factor
+
+                self.create_new_mask = self.create_new_vgg_mask
+
+                self.stride_1 = kernel_1
+
+        else:
+            if subsampling_factor == 1:
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+                    torch.nn.ReLU(),
+                )
+
+                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
+
+                self.subsampling_factor = subsampling_factor
+                self.kernel_2 = 3
+                self.stride_2 = 1
+
+                self.create_new_mask = self.create_new_conv2d_mask
+
+            else:
+                kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
+                    subsampling_factor,
+                    input_size,
+                )
+
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size, 3, 2),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
+                    torch.nn.ReLU(),
+                )
+
+                output_proj = conv_size * conv_2_output_size
+
+                self.subsampling_factor = subsampling_factor
+                self.kernel_2 = kernel_2
+                self.stride_2 = stride_2
+
+                self.create_new_mask = self.create_new_conv2d_mask
+
+        self.vgg_like = vgg_like
+        self.min_frame_length = 7
+
+        if output_size is not None:
+            self.output = torch.nn.Linear(output_proj, output_size)
+            self.output_size = output_size
+        else:
+            self.output = None
+            self.output_size = output_proj
+
+    def forward(
+        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: ConvInput input sequences. (B, T, D_feats)
+            mask: Mask of input sequences. (B, 1, T)
+        Returns:
+            x: ConvInput output sequences. (B, sub(T), D_out)
+            mask: Mask of output sequences. (B, 1, sub(T))
+        """
+        if mask is not None:
+            mask = self.create_new_mask(mask)
+            olens = max(mask.eq(0).sum(1))
+
+        b, t, f = x.size()
+        x = x.unsqueeze(1)  # (b, 1, t, f)
+
+        if chunk_size is not None:
+            max_input_length = int(
+                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
+            )
+            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
+            x = list(x)
+            x = torch.stack(x, dim=0)
+            N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
+            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
+
+        x = self.conv(x)
+
+        _, c, _, f = x.size()
+        if chunk_size is not None:
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
+        else:
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
+
+        if self.output is not None:
+            x = self.output(x)
+
+        return x, mask[:,:olens][:,:x.size(1)]
+
+    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        """Create a new mask for VGG output sequences.
+        Args:
+            mask: Mask of input sequences. (B, T)
+        Returns:
+            mask: Mask of output sequences. (B, sub(T))
+        """
+        if self.subsampling_factor > 1:
+            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
+            mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
+
+            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
+            mask = mask[:, :vgg2_t_len][:, ::2]
+        else:
+            mask = mask
+
+        return mask
+
+    def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        """Create new conformer mask for Conv2d output sequences.
+        Args:
+            mask: Mask of input sequences. (B, T)
+        Returns:
+            mask: Mask of output sequences. (B, sub(T))
+        """
+        if self.subsampling_factor > 1:
+            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
+        else:
+            return mask
+
+    def get_size_before_subsampling(self, size: int) -> int:
+        """Return the original size before subsampling for a given size.
+        Args:
+            size: Number of frames after subsampling.
+        Returns:
+            : Number of frames before subsampling.
+        """
+        return size * self.subsampling_factor
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index cae18c1..bb1f996 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -24,7 +24,7 @@
 from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
 from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
 from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
-from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder
 from funasr.models.e2e_transducer import TransducerModel
 from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
 from funasr.models.joint_network import JointNetwork
@@ -72,9 +72,9 @@
 encoder_choices = ClassChoices(
         "encoder",
         classes=dict(
-                encoder=Encoder,
+                chunk_conformer=ConformerChunkEncoder,
         ),
-        default="encoder",
+        default="chunk_conformer",
 )
 
 decoder_choices = ClassChoices(

--
Gitblit v1.9.1