From d79287c37e4e7ae2694a992cbbfb03a5ca4f7670 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 20 Feb 2024 14:05:58 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR merge

---
 funasr/models/conformer/encoder.py |  666 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 666 insertions(+), 0 deletions(-)

diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
index 1ca437d..1d252c2 100644
--- a/funasr/models/conformer/encoder.py
+++ b/funasr/models/conformer/encoder.py
@@ -14,6 +14,7 @@
     MultiHeadedAttention,  # noqa: H301
     RelPositionMultiHeadedAttention,  # noqa: H301
     LegacyRelPositionMultiHeadedAttention,  # noqa: H301
+    RelPositionMultiHeadedAttentionChunk,  # noqa: H301
 )
 from funasr.models.transformer.embedding import (
     PositionalEncoding,  # noqa: H301
@@ -610,4 +611,669 @@
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
+
 
+class CausalConvolution(torch.nn.Module):
+    """ConformerConvolution module definition.
+    Args:
+        channels: The number of channels.
+        kernel_size: Size of the convolving kernel.
+        activation: Type of activation function.
+        norm_args: Normalization module arguments.
+        causal: Whether to use causal convolution (set to True if streaming).
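+    Example (a minimal sketch; channel and kernel sizes are assumptions):
+        >>> conv = CausalConvolution(256, 31, causal=True)
+        >>> x = torch.randn(4, 100, 256)
+        >>> y, _ = conv(x)  # no cache: the module left-pads internally
+        >>> y.shape
+        torch.Size([4, 100, 256])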
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        norm_args: Dict = {},
+        causal: bool = False,
+    ) -> None:
+        """Construct an ConformerConvolution object."""
+        super().__init__()
+
+        assert (kernel_size - 1) % 2 == 0
+
+        self.kernel_size = kernel_size
+
+        self.pointwise_conv1 = torch.nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+        if causal:
+            self.lorder = kernel_size - 1
+            padding = 0
+        else:
+            self.lorder = 0
+            padding = (kernel_size - 1) // 2
+
+        self.depthwise_conv = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+        )
+        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
+        self.pointwise_conv2 = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+        self.activation = activation
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: Optional[torch.Tensor] = None,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute convolution module.
+        Args:
+            x: CausalConvolution input sequences. (B, T, D_hidden)
+            cache: CausalConvolution input cache. (1, D_hidden, conv_kernel - 1)
+            right_context: Number of frames in right context.
+        Returns:
+            x: CausalConvolution output sequences. (B, T, D_hidden)
+            cache: CausalConvolution output cache. (1, D_hidden, conv_kernel - 1)
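+        Example (illustrative streaming call; sizes are assumptions):
+            >>> conv = CausalConvolution(256, 31, causal=True)
+            >>> cache = torch.zeros(1, 256, 30)
+            >>> y, new_cache = conv(torch.randn(1, 16, 256), cache=cache)
+            >>> y.shape, new_cache.shape
+            (torch.Size([1, 16, 256]), torch.Size([1, 256, 30]))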
+        """
+        x = self.pointwise_conv1(x.transpose(1, 2))
+        x = torch.nn.functional.glu(x, dim=1)
+
+        if self.lorder > 0:
+            if cache is None:
+                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            else:
+                x = torch.cat([cache, x], dim=2)
+
+                if right_context > 0:
+                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
+                else:
+                    cache = x[:, :, -self.lorder :]
+
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x).transpose(1, 2)
+
+        return x, cache
+
+class ChunkEncoderLayer(torch.nn.Module):
+    """Chunk Conformer module definition.
+    Args:
+        block_size: Input/output size.
+        self_att: Self-attention module instance.
+        feed_forward: Feed-forward module instance.
+        feed_forward_macaron: Feed-forward module instance for macaron network.
+        conv_mod: Convolution module instance.
+        norm_class: Normalization module class.
+        norm_args: Normalization module arguments.
+        dropout_rate: Dropout rate.
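+    Example (construction sketch mirroring the encoder below; sizes are assumptions):
+        >>> layer = ChunkEncoderLayer(
+        ...     256,
+        ...     RelPositionMultiHeadedAttentionChunk(4, 256, 0.0, False),
+        ...     PositionwiseFeedForward(256, 2048, 0.1),
+        ...     PositionwiseFeedForward(256, 2048, 0.1),
+        ...     CausalConvolution(256, 31, causal=True),
+        ... )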
+    """
+
+    def __init__(
+        self,
+        block_size: int,
+        self_att: torch.nn.Module,
+        feed_forward: torch.nn.Module,
+        feed_forward_macaron: torch.nn.Module,
+        conv_mod: torch.nn.Module,
+        norm_class: torch.nn.Module = LayerNorm,
+        norm_args: Dict = {},
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Construct a Conformer object."""
+        super().__init__()
+
+        self.self_att = self_att
+
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.feed_forward_scale = 0.5
+
+        self.conv_mod = conv_mod
+
+        self.norm_feed_forward = norm_class(block_size, **norm_args)
+        self.norm_self_att = norm_class(block_size, **norm_args)
+
+        self.norm_macaron = norm_class(block_size, **norm_args)
+        self.norm_conv = norm_class(block_size, **norm_args)
+        self.norm_final = norm_class(block_size, **norm_args)
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        self.block_size = block_size
+        self.cache = None
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset self-attention and convolution modules cache for streaming.
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
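+        Example (shapes only; assumes block_size 256 and conv kernel size 31):
+            >>> layer.reset_streaming_cache(32, torch.device("cpu"))
+            >>> layer.cache[0].shape  # self-attention cache
+            torch.Size([1, 32, 256])
+            >>> layer.cache[1].shape  # convolution cache
+            torch.Size([1, 256, 30])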
+        """
+        self.cache = [
+            torch.zeros(
+                (1, left_context, self.block_size),
+                device=device,
+            ),
+            torch.zeros(
+                (
+                    1,
+                    self.block_size,
+                    self.conv_mod.kernel_size - 1,
+                ),
+                device=device,
+            ),
+        ]
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            mask: Source mask. (B, T)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.dropout(
+            self.feed_forward_macaron(x)
+        )
+
+        residual = x
+        x = self.norm_self_att(x)
+        x_q = x
+        x = residual + self.dropout(
+            self.self_att(
+                x_q,
+                x,
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+        )
+
+        residual = x
+
+        x = self.norm_conv(x)
+        x, _ = self.conv_mod(x)
+        x = residual + self.dropout(x)
+        residual = x
+
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
+
+        x = self.norm_final(x)
+        return x, mask, pos_enc
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode chunk of input sequence.
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T_2)
+            chunk_size: Size of the current chunk, in frames.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
+
+        residual = x
+        x = self.norm_self_att(x)
+        if left_context > 0:
+            key = torch.cat([self.cache[0], x], dim=1)
+        else:
+            key = x
+        val = key
+
+        if right_context > 0:
+            att_cache = key[:, -(left_context + right_context) : -right_context, :]
+        else:
+            att_cache = key[:, -left_context:, :]
+        x = residual + self.self_att(
+            x,
+            key,
+            val,
+            pos_enc,
+            mask,
+            left_context=left_context,
+        )
+
+        residual = x
+        x = self.norm_conv(x)
+        x, conv_cache = self.conv_mod(
+            x, cache=self.cache[1], right_context=right_context
+        )
+        x = residual + x
+        residual = x
+
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.feed_forward(x)
+
+        x = self.norm_final(x)
+        self.cache = [att_cache, conv_cache]
+
+        return x, pos_enc
+
+@tables.register("encoder_classes", "ChunkConformerEncoder")
+class ConformerChunkEncoder(torch.nn.Module):
+    """Encoder module definition.
+    Args:
+        input_size: Input size.
+        body_conf: Encoder body configuration.
+        input_conf: Encoder input configuration.
+        main_conf: Encoder main configuration.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        embed_vgg_like: bool = False,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 3,
+        macaron_style: bool = False,
+        rel_pos_type: str = "legacy",
+        pos_enc_layer_type: str = "rel_pos",
+        selfattention_layer_type: str = "rel_selfattn",
+        activation_type: str = "swish",
+        use_cnn_module: bool = True,
+        zero_triu: bool = False,
+        norm_type: str = "layer_norm",
+        cnn_module_kernel: int = 31,
+        conv_mod_norm_eps: float = 0.00001,
+        conv_mod_norm_momentum: float = 0.1,
+        simplified_att_score: bool = False,
+        dynamic_chunk_training: bool = False,
+        short_chunk_threshold: float = 0.75,
+        short_chunk_size: int = 25,
+        left_chunk_size: int = 0,
+        time_reduction_factor: int = 1,
+        unified_model_training: bool = False,
+        default_chunk_size: int = 16,
+        jitter_range: int = 4,
+        subsampling_factor: int = 1,
+    ) -> None:
+        """Construct an Encoder object."""
+        super().__init__()
+
+        self.embed = StreamingConvInput(
+            input_size=input_size,
+            conv_size=output_size,
+            subsampling_factor=subsampling_factor,
+            vgg_like=embed_vgg_like,
+            output_size=output_size,
+        )
+
+        self.pos_enc = StreamingRelPositionalEncoding(
+            output_size,
+            positional_dropout_rate,
+        )
+
+        activation = get_activation(activation_type)
+
+        pos_wise_args = (
+            output_size,
+            linear_units,
+            positional_dropout_rate,
+            activation,
+        )
+
+        conv_mod_norm_args = {
+            "eps": conv_mod_norm_eps,
+            "momentum": conv_mod_norm_momentum,
+        }
+
+        conv_mod_args = (
+            output_size,
+            cnn_module_kernel,
+            activation,
+            conv_mod_norm_args,
+            dynamic_chunk_training or unified_model_training,
+        )
+
+        mult_att_args = (
+            attention_heads,
+            output_size,
+            attention_dropout_rate,
+            simplified_att_score,
+        )
+
+        self.encoders = MultiBlocks(
+            [
+                ChunkEncoderLayer(
+                    output_size,
+                    RelPositionMultiHeadedAttentionChunk(*mult_att_args),
+                    PositionwiseFeedForward(*pos_wise_args),
+                    PositionwiseFeedForward(*pos_wise_args),
+                    CausalConvolution(*conv_mod_args),
+                    dropout_rate=dropout_rate,
+                )
+                for _ in range(num_blocks)
+            ],
+            output_size,
+        )
+
+        self._output_size = output_size
+
+        self.dynamic_chunk_training = dynamic_chunk_training
+        self.short_chunk_threshold = short_chunk_threshold
+        self.short_chunk_size = short_chunk_size
+        self.left_chunk_size = left_chunk_size
+
+        self.unified_model_training = unified_model_training
+        self.default_chunk_size = default_chunk_size
+        self.jitter_range = jitter_range
+
+        self.time_reduction_factor = time_reduction_factor
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
+        """Return the corresponding number of sample for a given chunk size, in frames.
+        Where size is the number of features frames after applying subsampling.
+        Args:
+            size: Number of frames after subsampling.
+            hop_length: Frontend's hop length
+        Returns:
+            : Number of raw samples
+        """
+        return self.embed.get_size_before_subsampling(size) * hop_length
+
+    def get_encoder_input_size(self, size: int) -> int:
+        """Return the corresponding number of sample for a given chunk size, in frames.
+        Where size is the number of features frames after applying subsampling.
+        Args:
+            size: Number of frames after subsampling.
+        Returns:
+            : Number of raw samples
+        """
+        return self.embed.get_size_before_subsampling(size)
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset encoder streaming cache.
+        Args:
+            left_context: Number of frames in left context.
+            device: Device to use for the cache tensors.
+        """
+        return self.encoders.reset_streaming_cache(left_context, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+           olens: Encoder output lengths. (B,)
+           (In unified training mode, (x_utt, x_chunk, olens) is returned instead.)
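+        Example (a minimal sketch; feature size and lengths are assumptions):
+            >>> encoder = ConformerChunkEncoder(80, output_size=256, num_blocks=6)
+            >>> feats = torch.randn(2, 200, 80)
+            >>> feats_len = torch.tensor([200, 180])
+            >>> x, olens, _ = encoder(feats, feats_len)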
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+
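+        # Unified training: encode the utterance twice, once with full context
+        # and once under a chunk mask, so a single model supports both offline
+        # and streaming decoding.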
+        if self.unified_model_training:
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(
+                    -self.jitter_range, self.jitter_range + 1, (1,)
+                ).item()
+            else:
+                chunk_size = self.default_chunk_size
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
+            x_utt = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=None,
+            )
+            x_chunk = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+
+            olens = mask.eq(0).sum(1)
+            if self.time_reduction_factor > 1:
+                x_utt = x_utt[:, :: self.time_reduction_factor, :]
+                x_chunk = x_chunk[:, :: self.time_reduction_factor, :]
+                olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
+
+            return x_utt, x_chunk, olens
+
+        elif self.dynamic_chunk_training:
+            max_len = x.size(1)
+            if self.training:
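+                # Sample a chunk size uniformly at random; draws above the
+                # threshold fall back to full-utterance training, smaller draws
+                # are folded into a short chunk (dynamic chunk training).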
+                chunk_size = torch.randint(1, max_len, (1,)).item()
+
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
+            else:
+                chunk_size = self.default_chunk_size
+
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
+
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
+        else:
+            x, mask = self.embed(x, mask, None)
+            pos_enc = self.pos_enc(x)
+            chunk_mask = None
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+
+        olens = mask.eq(0).sum(1)
+        if self.time_reduction_factor > 1:
+            x = x[:, :: self.time_reduction_factor, :]
+            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
+
+        return x, olens, None
+
+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Encode input sequences without chunking (full utterance).
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x_utt: Encoder outputs. (B, T_out, D_enc)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:, :: self.time_reduction_factor, :]
+        return x_utt
+
+    def simu_chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Encode input sequences with a fixed chunk mask (simulated streaming).
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+            chunk_size: Size of each chunk, in frames.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+            x: Encoder outputs. (B, T_out, D_enc)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len)
+
+        x, mask = self.embed(x, mask, chunk_size)
+        pos_enc = self.pos_enc(x)
+        chunk_mask = make_chunk_mask(
+            x.size(1),
+            chunk_size,
+            left_chunk_size=self.left_chunk_size,
+            device=x.device,
+        )
+
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+        if self.time_reduction_factor > 1:
+            x = x[:, :: self.time_reduction_factor, :]
+
+        return x
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        processed_frames: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Encode input sequences as chunks.
+        Args:
+            x: Encoder input features. (1, T_in, F)
+            x_len: Encoder input features lengths. (1,)
+            processed_frames: Number of frames already seen.
+            chunk_size: Size of the current chunk, in frames.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
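+        Example (illustrative streaming loop; `chunks` and the 16-frame hop
+        are assumptions):
+            >>> encoder.reset_streaming_cache(left_context=32, device=x.device)
+            >>> for i, (chunk, chunk_len) in enumerate(chunks):
+            ...     out = encoder.chunk_forward(
+            ...         chunk, chunk_len, processed_frames=torch.tensor(i * 16)
+            ...     )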
+        """
+        mask = make_source_mask(x_len)
+        x, mask = self.embed(x, mask, None)
+
+        if left_context > 0:
+            processed_mask = (
+                torch.arange(left_context, device=x.device)
+                .view(1, left_context)
+                .flip(1)
+            )
+            processed_mask = processed_mask >= processed_frames
+            mask = torch.cat([processed_mask, mask], dim=1)
+        pos_enc = self.pos_enc(x, left_context=left_context)
+        x = self.encoders.chunk_forward(
+            x,
+            pos_enc,
+            mask,
+            chunk_size=chunk_size,
+            left_context=left_context,
+            right_context=right_context,
+        )
+
+        if right_context > 0:
+            x = x[:, 0:-right_context, :]
+
+        if self.time_reduction_factor > 1:
+            x = x[:, :: self.time_reduction_factor, :]
+        return x

--
Gitblit v1.9.1