kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/conformer/encoder.py
@@ -14,6 +14,7 @@
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttentionChunk,
)
from funasr.models.transformer.embedding import (
    PositionalEncoding,  # noqa: H301
@@ -46,6 +47,7 @@
from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
from funasr.models.transformer.utils.subsampling import StreamingConvInput
from funasr.register import tables
import pdb
class ConvolutionModule(nn.Module):
@@ -145,16 +147,16 @@
    """
    def __init__(
            self,
            size,
            self_attn,
            feed_forward,
            feed_forward_macaron,
            conv_module,
            dropout_rate,
            normalize_before=True,
            concat_after=False,
            stochastic_depth_rate=0.0,
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
@@ -265,9 +267,7 @@
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)
@@ -320,32 +320,32 @@
    """
    def __init__(
            self,
            input_size: int,
            output_size: int = 256,
            attention_heads: int = 4,
            linear_units: int = 2048,
            num_blocks: int = 6,
            dropout_rate: float = 0.1,
            positional_dropout_rate: float = 0.1,
            attention_dropout_rate: float = 0.0,
            input_layer: str = "conv2d",
            normalize_before: bool = True,
            concat_after: bool = False,
            positionwise_layer_type: str = "linear",
            positionwise_conv_kernel_size: int = 3,
            macaron_style: bool = False,
            rel_pos_type: str = "legacy",
            pos_enc_layer_type: str = "rel_pos",
            selfattention_layer_type: str = "rel_selfattn",
            activation_type: str = "swish",
            use_cnn_module: bool = True,
            zero_triu: bool = False,
            cnn_module_kernel: int = 31,
            padding_idx: int = -1,
            interctc_layer_idx: List[int] = [],
            interctc_use_conditioning: bool = False,
            stochastic_depth_rate: Union[float, List[float]] = 0.0,
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
    ):
        super().__init__()
        self._output_size = output_size
@@ -372,9 +372,7 @@
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
            logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@@ -431,9 +429,7 @@
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
            self.embed = torch.nn.Sequential(pos_enc_class(output_size, positional_dropout_rate))
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
@@ -479,9 +475,7 @@
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
            logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -533,11 +527,11 @@
        return self._output_size
    def forward(
            self,
            xs_pad: torch.Tensor,
            ilens: torch.Tensor,
            prev_states: torch.Tensor = None,
            ctc: CTC = None,
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
@@ -555,11 +549,11 @@
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
                isinstance(self.embed, Conv2dSubsampling)
                or isinstance(self.embed, Conv2dSubsampling2)
                or isinstance(self.embed, Conv2dSubsampling6)
                or isinstance(self.embed, Conv2dSubsampling8)
                or isinstance(self.embed, Conv2dSubsamplingPad)
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
            or isinstance(self.embed, Conv2dSubsamplingPad)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
@@ -611,3 +605,656 @@
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
class CausalConvolution(torch.nn.Module):
    """Conformer convolution module with optional causal (left-only) context.
    The input goes through pointwise conv -> GLU -> depthwise conv -> BatchNorm
    -> activation -> pointwise conv. When ``causal`` is True the depthwise
    convolution receives no right padding, so each output frame only depends on
    the current and ``kernel_size - 1`` previous frames; those previous frames
    can be carried between calls through ``cache`` for chunk-by-chunk streaming.
    Args:
        channels: The number of channels.
        kernel_size: Size of the convolving kernel (must be odd).
        activation: Activation module applied after normalization. When None
            (default) a new ``torch.nn.ReLU`` is created for this instance.
        norm_args: Keyword arguments forwarded to ``torch.nn.BatchNorm1d``
            (defaults to none).
        causal: Whether to use causal convolution (set to True if streaming).
    """
    def __init__(
        self,
        channels: int,
        kernel_size: int,
        activation: Optional[torch.nn.Module] = None,
        norm_args: Optional[Dict] = None,
        causal: bool = False,
    ) -> None:
        """Construct a CausalConvolution object."""
        super().__init__()
        # Symmetric "same" padding in the non-causal case requires an odd kernel.
        assert (kernel_size - 1) % 2 == 0, "kernel_size must be odd"
        if activation is None:
            # Create a fresh module per instance rather than sharing one default object.
            activation = torch.nn.ReLU()
        if norm_args is None:
            norm_args = {}
        self.kernel_size = kernel_size
        self.pointwise_conv1 = torch.nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        if causal:
            # All padding goes on the left (done manually in forward / via cache).
            self.lorder = kernel_size - 1
            padding = 0
        else:
            self.lorder = 0
            padding = (kernel_size - 1) // 2
        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
        )
        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.activation = activation
    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x: Input sequences. (B, T, D_hidden)
            cache: Previous frames prepended on the time axis before the
                depthwise convolution, laid out as (B, D_hidden, kernel_size - 1).
                Only used when the module is causal; when None the sequence is
                zero-padded on the left instead.
            right_context: Number of trailing look-ahead frames of ``x`` that must
                not be stored in the returned cache.
        Returns:
            x: Output sequences. (B, T, D_hidden)
            cache: Updated cache (B, D_hidden, kernel_size - 1) when the module is
                causal and a cache was given; otherwise ``cache`` is returned
                unchanged (None if none was given).
        """
        x = self.pointwise_conv1(x.transpose(1, 2))
        x = torch.nn.functional.glu(x, dim=1)
        if self.lorder > 0:
            if cache is None:
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)
                # Keep the last ``lorder`` frames that precede the look-ahead frames.
                if right_context > 0:
                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
                else:
                    cache = x[:, :, -self.lorder :]
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x).transpose(1, 2)
        return x, cache
class ChunkEncoderLayer(torch.nn.Module):
    """Chunk-aware Conformer block.
    Applies, in order and each with pre-normalization and a residual
    connection: a half-step macaron feed-forward, self-attention, the
    convolution module and a half-step feed-forward, followed by a final
    normalization. ``forward`` processes whole sequences (optionally restricted
    by a chunk mask); ``chunk_forward`` processes one streaming chunk using the
    caches created by ``reset_streaming_cache``.
    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance; must expose ``kernel_size``.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments (defaults to none).
        dropout_rate: Dropout rate applied to sub-module outputs in ``forward``.
    """
    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = LayerNorm,
        norm_args: Optional[Dict] = None,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a ChunkEncoderLayer object."""
        super().__init__()
        if norm_args is None:
            norm_args = {}
        self.self_att = self_att
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Both feed-forward modules contribute half a residual step (macaron style).
        self.feed_forward_scale = 0.5
        self.conv_mod = conv_mod
        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)
        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.block_size = block_size
        # [attention key/value cache, convolution cache]; set by reset_streaming_cache.
        self.cache = None
    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.
        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]
    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.
        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)
        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask, passed through unchanged. (B, T)
            pos_enc: Positional embedding sequences, passed through unchanged.
        """
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward_macaron(x))
        residual = x
        x = self.norm_self_att(x)
        x_q = x
        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )
        residual = x
        x = self.norm_conv(x)
        # The convolution cache is not used for full-sequence processing.
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)
        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
        x = self.norm_final(x)
        return x, mask, pos_enc
    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence and update ``self.cache``.
        ``reset_streaming_cache`` must have been called first. No dropout is
        applied on this path.
        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Not used by this method; kept for interface compatibility.
            left_context: Number of cached frames prepended to the keys/values.
            right_context: Number of trailing look-ahead frames in ``x``; they are
                excluded from the caches saved for the next chunk.
        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
        residual = x
        x = self.norm_self_att(x)
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key
        # Save the last ``left_context`` frames that precede the look-ahead frames.
        # Explicit bounds are used because ``key[:, -0:]`` would select the whole
        # sequence when left_context == 0 instead of an empty cache.
        cache_end = max(key.size(1) - right_context, 0)
        att_cache = key[:, max(cache_end - left_context, 0) : cache_end, :]
        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )
        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(x, cache=self.cache[1], right_context=right_context)
        x = residual + x
        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)
        x = self.norm_final(x)
        self.cache = [att_cache, conv_cache]
        return x, pos_enc
@tables.register("encoder_classes", "ChunkConformerEncoder")
class ConformerChunkEncoder(torch.nn.Module):
    """Streaming (chunk-based) Conformer encoder.
    Stacks ``num_blocks`` ChunkEncoderLayer blocks on top of a
    StreamingConvInput front-end and a StreamingRelPositionalEncoding.
    ``forward`` either encodes full utterances, trains with randomly sized
    chunks (``dynamic_chunk_training``), or returns both full-utterance and
    chunk-masked outputs (``unified_model_training``). ``chunk_forward`` runs
    stateful chunk-by-chunk inference using the per-block streaming caches.
    Args:
        input_size: Input feature dimension.
        output_size: Encoder block dimension.
        attention_heads: Number of attention heads.
        linear_units: Hidden size of the position-wise feed-forward modules.
        num_blocks: Number of ChunkEncoderLayer blocks.
        dropout_rate: Dropout rate inside each block.
        positional_dropout_rate: Dropout rate of the positional encoding; also
            used as the dropout rate of the feed-forward modules.
        attention_dropout_rate: Dropout rate passed to the attention modules.
        embed_vgg_like: Passed to StreamingConvInput as ``vgg_like``.
        activation_type: Activation name passed to ``get_activation``.
        cnn_module_kernel: Kernel size of the convolution modules.
        conv_mod_norm_eps: BatchNorm epsilon in the convolution modules.
        conv_mod_norm_momentum: BatchNorm momentum in the convolution modules.
        simplified_att_score: Passed to RelPositionMultiHeadedAttentionChunk.
        dynamic_chunk_training: Train with randomly sized chunks; also makes the
            convolution modules causal.
        short_chunk_threshold: In dynamic chunk training, sampled sizes above
            ``short_chunk_threshold * T_in`` use the full utterance.
        short_chunk_size: In dynamic chunk training, smaller sampled sizes are
            folded into ``[1, short_chunk_size]``.
        left_chunk_size: Passed to ``make_chunk_mask`` as ``left_chunk_size``.
        time_reduction_factor: Keep every n-th output frame when > 1.
        unified_model_training: Return both full-utterance and chunk-masked
            outputs; also makes the convolution modules causal.
        default_chunk_size: Chunk size used outside dynamic sampling.
        jitter_range: In unified training, the chunk size is jittered by up to
            this many frames in either direction.
        subsampling_factor: Passed to StreamingConvInput.
        normalize_before, concat_after, positionwise_layer_type,
        positionwise_conv_kernel_size, macaron_style, rel_pos_type,
        pos_enc_layer_type, selfattention_layer_type, use_cnn_module,
        zero_triu, norm_type: Accepted for configuration compatibility; not
            used by this implementation.
    """
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        embed_vgg_like: bool = False,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        norm_type: str = "layer_norm",
        cnn_module_kernel: int = 31,
        conv_mod_norm_eps: float = 0.00001,
        conv_mod_norm_momentum: float = 0.1,
        simplified_att_score: bool = False,
        dynamic_chunk_training: bool = False,
        short_chunk_threshold: float = 0.75,
        short_chunk_size: int = 25,
        left_chunk_size: int = 0,
        time_reduction_factor: int = 1,
        unified_model_training: bool = False,
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()
        self.embed = StreamingConvInput(
            input_size=input_size,
            conv_size=output_size,
            subsampling_factor=subsampling_factor,
            vgg_like=embed_vgg_like,
            output_size=output_size,
        )
        self.pos_enc = StreamingRelPositionalEncoding(
            output_size,
            positional_dropout_rate,
        )
        # One activation instance is shared by all feed-forward and conv modules.
        activation = get_activation(activation_type)
        pos_wise_args = (
            output_size,
            linear_units,
            positional_dropout_rate,
            activation,
        )
        conv_mod_norm_args = {
            "eps": conv_mod_norm_eps,
            "momentum": conv_mod_norm_momentum,
        }
        # Convolutions are causal whenever chunk-based training is enabled.
        conv_mod_args = (
            output_size,
            cnn_module_kernel,
            activation,
            conv_mod_norm_args,
            dynamic_chunk_training or unified_model_training,
        )
        mult_att_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            simplified_att_score,
        )
        self.encoders = MultiBlocks(
            [
                ChunkEncoderLayer(
                    output_size,
                    RelPositionMultiHeadedAttentionChunk(*mult_att_args),
                    PositionwiseFeedForward(*pos_wise_args),
                    PositionwiseFeedForward(*pos_wise_args),
                    CausalConvolution(*conv_mod_args),
                    dropout_rate=dropout_rate,
                )
                for _ in range(num_blocks)
            ],
            output_size,
        )
        self._output_size = output_size
        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
        self.short_chunk_size = short_chunk_size
        self.left_chunk_size = left_chunk_size
        self.unified_model_training = unified_model_training
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range
        self.time_reduction_factor = time_reduction_factor
    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size
    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the number of raw samples corresponding to a chunk size.
        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length.
        Returns:
            : Number of raw samples.
        """
        return self.embed.get_size_before_subsampling(size) * hop_length
    def get_encoder_input_size(self, size: int) -> int:
        """Return the number of input feature frames for a chunk size.
        Args:
            size: Number of frames after subsampling.
        Returns:
            : Number of feature frames before subsampling.
        """
        return self.embed.get_size_before_subsampling(size)
    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.
        Args:
            left_context: Number of frames in left context.
            device: Device on which to allocate the cache tensors.
        """
        return self.encoders.reset_streaming_cache(left_context, device)
    def _raise_if_too_short(self, x: torch.Tensor) -> None:
        """Raise TooShortUttError if ``x`` has too few frames for subsampling."""
        short_status, limit_size = check_short_utt(self.embed.subsampling_factor, x.size(1))
        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )
    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode input sequences.
        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
        Returns:
            With ``unified_model_training``:
                x_utt: Full-context encoder outputs. (B, T_out, D_enc)
                x_chunk: Chunk-masked encoder outputs. (B, T_out, D_enc)
                olens: Encoder output lengths. (B,)
            Otherwise:
                x: Encoder outputs. (B, T_out, D_enc)
                olens: Encoder output lengths. (B,)
                None
        Raises:
            TooShortUttError: If ``check_short_utt`` reports the input too short.
        """
        self._raise_if_too_short(x)
        mask = make_source_mask(x_len).to(x.device)
        if self.unified_model_training:
            if self.training:
                # Jitter the chunk size around the default during training.
                chunk_size = (
                    self.default_chunk_size
                    + torch.randint(-self.jitter_range, self.jitter_range + 1, (1,)).item()
                )
            else:
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            # The same embedded input is encoded with full and chunk-restricted context.
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
            # Output lengths: count of positions where the mask is 0.
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:, :: self.time_reduction_factor, :]
                x_chunk = x_chunk[:, :: self.time_reduction_factor, :]
                olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
            return x_utt, x_chunk, olens
        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            if self.training:
                # Draws above the threshold use the full utterance; smaller draws
                # are folded into [1, short_chunk_size].
                chunk_size = torch.randint(1, max_len, (1,)).item()
                if chunk_size > (max_len * self.short_chunk_threshold):
                    chunk_size = max_len
                else:
                    chunk_size = (chunk_size % self.short_chunk_size) + 1
            else:
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)
            chunk_mask = None
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
        return x, olens, None
    def full_utt_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> torch.Tensor:
        """Encode input sequences with full (unchunked) context.
        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
        Returns:
            x_utt: Encoder outputs. (B, T_out, D_enc)
        Raises:
            TooShortUttError: If ``check_short_utt`` reports the input too short.
        """
        self._raise_if_too_short(x)
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        pos_enc = self.pos_enc(x)
        x_utt = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=None,
        )
        if self.time_reduction_factor > 1:
            x_utt = x_utt[:, :: self.time_reduction_factor, :]
        return x_utt
    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode full sequences under a chunk mask (simulated streaming).
        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
            chunk_size: Chunk size used for the front-end and the chunk mask.
            left_context: Not used; the mask uses ``self.left_chunk_size``.
            right_context: Not used.
        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        Raises:
            TooShortUttError: If ``check_short_utt`` reports the input too short.
        """
        self._raise_if_too_short(x)
        # Build the mask on the input's device, as forward() does.
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
        return x
    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode one chunk of input for streaming inference.
        Uses and updates the per-block caches, so ``reset_streaming_cache`` must
        have been called first.
        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            chunk_size: Forwarded to the blocks' ``chunk_forward``.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context; these trailing
                frames are dropped from the output.
        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        # Build the mask on the input's device so it can be concatenated with
        # processed_mask below (which is created on x.device).
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        if left_context > 0:
            # Left-context slot i (counted backwards from the chunk) is flagged
            # True when it reaches further back than the frames processed so far;
            # presumably True means "ignore", matching the source mask convention.
            processed_mask = (
                torch.arange(left_context, device=x.device).view(1, left_context).flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)
        pos_enc = self.pos_enc(x, left_context=left_context)
        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )
        if right_context > 0:
            x = x[:, 0:-right_context, :]
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
        return x