aky15
2023-04-14 256035b6c1fa6115b6f33972ed243eb43f3e4299
rnnt reorg
8 files changed
1 file renamed
14 files deleted
3565 ■■■■■ lines changed
funasr/models/e2e_transducer.py 2 ●●●
funasr/models/e2e_transducer_unified.py 2 ●●●
funasr/models/encoder/chunk_encoder.py 292 ●●●●●
funasr/models/encoder/chunk_encoder_blocks/__init__.py
funasr/models/encoder/chunk_encoder_blocks/branchformer.py 178 ●●●●●
funasr/models/encoder/chunk_encoder_blocks/conformer.py 198 ●●●●●
funasr/models/encoder/chunk_encoder_blocks/conv1d.py 221 ●●●●●
funasr/models/encoder/chunk_encoder_blocks/conv_input.py 222 ●●●●●
funasr/models/encoder/chunk_encoder_blocks/linear_input.py 52 ●●●●●
funasr/models/encoder/chunk_encoder_modules/__init__.py
funasr/models/encoder/chunk_encoder_modules/attention.py 246 ●●●●●
funasr/models/encoder/chunk_encoder_modules/convolution.py 196 ●●●●●
funasr/models/encoder/chunk_encoder_modules/multi_blocks.py 105 ●●●●●
funasr/models/encoder/chunk_encoder_modules/positional_encoding.py 91 ●●●●●
funasr/models/encoder/chunk_encoder_utils/building.py 352 ●●●●●
funasr/models/encoder/chunk_encoder_utils/validation.py 171 ●●●●●
funasr/models/encoder/conformer_encoder.py 640 ●●●●●
funasr/modules/attention.py 220 ●●●●●
funasr/modules/embedding.py 77 ●●●●●
funasr/modules/normalization.py
funasr/modules/repeat.py 92 ●●●●●
funasr/modules/subsampling.py 202 ●●●●●
funasr/tasks/asr_transducer.py 6 ●●●●
funasr/models/e2e_transducer.py
@@ -12,7 +12,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
funasr/models/e2e_transducer_unified.py
@@ -11,7 +11,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
-from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
funasr/models/encoder/chunk_encoder.py
File was deleted
funasr/models/encoder/chunk_encoder_blocks/__init__.py
File was deleted
funasr/models/encoder/chunk_encoder_blocks/branchformer.py
File was deleted
funasr/models/encoder/chunk_encoder_blocks/conformer.py
File was deleted
funasr/models/encoder/chunk_encoder_blocks/conv1d.py
File was deleted
funasr/models/encoder/chunk_encoder_blocks/conv_input.py
File was deleted
funasr/models/encoder/chunk_encoder_blocks/linear_input.py
File was deleted
funasr/models/encoder/chunk_encoder_modules/__init__.py
File was deleted
funasr/models/encoder/chunk_encoder_modules/attention.py
File was deleted
funasr/models/encoder/chunk_encoder_modules/convolution.py
File was deleted
funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
File was deleted
funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
File was deleted
funasr/models/encoder/chunk_encoder_utils/building.py
File was deleted
funasr/models/encoder/chunk_encoder_utils/validation.py
File was deleted
funasr/models/encoder/conformer_encoder.py
@@ -8,6 +8,7 @@
from typing import Optional
from typing import Tuple
from typing import Union
+from typing import Dict
import torch
from torch import nn
@@ -18,6 +19,7 @@
from funasr.modules.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
+    RelPositionMultiHeadedAttentionChunk,
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
)
from funasr.modules.embedding import (
@@ -25,16 +27,24 @@
    ScaledPositionalEncoding,  # noqa: H301
    RelPositionalEncoding,  # noqa: H301
    LegacyRelPositionalEncoding,  # noqa: H301
+    StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.normalization import get_normalization
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.nets_utils import (
+    TooShortUttError,
+    check_short_utt,
+    make_chunk_mask,
+    make_source_mask,
+)
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
-from funasr.modules.repeat import repeat
+from funasr.modules.repeat import repeat, MultiBlocks
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
@@ -42,6 +52,8 @@
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
+from funasr.modules.subsampling import StreamingConvInput
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
@@ -275,6 +287,188 @@
            return (x, pos_emb), mask
        return x, mask
class ChunkEncoderLayer(torch.nn.Module):
    """Chunk Conformer module definition.
    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
    """
    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a Conformer object."""
        super().__init__()
        self.self_att = self_att
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.feed_forward_scale = 0.5
        self.conv_mod = conv_mod
        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)
        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.block_size = block_size
        self.cache = None
    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.
        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]
    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.
        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)
        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )
        residual = x
        x = self.norm_self_att(x)
        x_q = x
        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )
        residual = x
        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)
        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
        x = self.norm_final(x)
        return x, mask, pos_enc
    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence.
        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Number of frames in chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.
        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
        residual = x
        x = self.norm_self_att(x)
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key
        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]
        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )
        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x
        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)
        x = self.norm_final(x)
        self.cache = [att_cache, conv_cache]
        return x, pos_enc
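For orientation, the macaron residual pattern that ChunkEncoderLayer applies around each sub-module reduces to the following minimal sketch (plain PyTorch stand-ins with toy sizes, not FunASR modules):

import torch

block_size = 8
x = torch.randn(2, 16, block_size)  # (B, T, D_block)

norm = torch.nn.LayerNorm(block_size)
feed_forward = torch.nn.Linear(block_size, block_size)
feed_forward_scale = 0.5  # half-step feed-forward, as in macaron-style Conformer

# Pre-norm residual branch: x <- x + 0.5 * FF(LayerNorm(x))
residual = x
x = residual + feed_forward_scale * feed_forward(norm(x))
print(x.shape)  # torch.Size([2, 16, 8])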
class ConformerEncoder(AbsEncoder):
@@ -604,3 +798,447 @@
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
class CausalConvolution(torch.nn.Module):
    """ConformerConvolution module definition.
    Args:
        channels: The number of channels.
        kernel_size: Size of the convolving kernel.
        activation: Type of activation function.
        norm_args: Normalization module arguments.
        causal: Whether to use causal convolution (set to True if streaming).
    """
    def __init__(
        self,
        channels: int,
        kernel_size: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        norm_args: Dict = {},
        causal: bool = False,
    ) -> None:
        """Construct an ConformerConvolution object."""
        super().__init__()
        assert (kernel_size - 1) % 2 == 0
        self.kernel_size = kernel_size
        self.pointwise_conv1 = torch.nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        if causal:
            self.lorder = kernel_size - 1
            padding = 0
        else:
            self.lorder = 0
            padding = (kernel_size - 1) // 2
        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
        )
        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.activation = activation
    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x: CausalConvolution input sequences. (B, T, D_hidden)
            cache: CausalConvolution input cache. (1, conv_kernel, D_hidden)
            right_context: Number of frames in right context.
        Returns:
            x: CausalConvolution output sequences. (B, T, D_hidden)
            cache: CausalConvolution output cache. (1, conv_kernel, D_hidden)
        """
        x = self.pointwise_conv1(x.transpose(1, 2))
        x = torch.nn.functional.glu(x, dim=1)
        if self.lorder > 0:
            if cache is None:
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)
                if right_context > 0:
                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
                else:
                    cache = x[:, :, -self.lorder :]
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x).transpose(1, 2)
        return x, cache
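A quick, hedged check of the causal branch above: left-padding by (kernel_size - 1) frames combined with padding=0 in the depthwise convolution preserves the time dimension without looking at future frames (stock PyTorch, toy sizes):

import torch

channels, kernel_size, T = 4, 5, 10
lorder = kernel_size - 1  # frames of left context the kernel needs
conv = torch.nn.Conv1d(channels, channels, kernel_size, padding=0, groups=channels)

x = torch.randn(1, channels, T)  # (B, D, T), the layout inside the module
x_padded = torch.nn.functional.pad(x, (lorder, 0))  # pad on the left only
y = conv(x_padded)
print(y.shape)  # torch.Size([1, 4, 10]) -- same T, no future frames used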
class ConformerChunkEncoder(torch.nn.Module):
    """Encoder module definition.
    Args:
        input_size: Input size.
        body_conf: Encoder body configuration.
        input_conf: Encoder input configuration.
        main_conf: Encoder main configuration.
    """
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        embed_vgg_like: bool = False,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        norm_type: str = "layer_norm",
        cnn_module_kernel: int = 31,
        conv_mod_norm_eps: float = 0.00001,
        conv_mod_norm_momentum: float = 0.1,
        simplified_att_score: bool = False,
        dynamic_chunk_training: bool = False,
        short_chunk_threshold: float = 0.75,
        short_chunk_size: int = 25,
        left_chunk_size: int = 0,
        time_reduction_factor: int = 1,
        unified_model_training: bool = False,
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
        **activation_parameters,
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()
        assert check_argument_types()
        self.embed = StreamingConvInput(
            input_size,
            output_size,
            subsampling_factor,
            vgg_like=embed_vgg_like,
            output_size=output_size,
        )
        self.pos_enc = StreamingRelPositionalEncoding(
            output_size,
            positional_dropout_rate,
        )
        activation = get_activation(
            activation_type, **activation_parameters
        )
        pos_wise_args = (
            output_size,
            linear_units,
            positional_dropout_rate,
            activation,
        )
        conv_mod_norm_args = {
            "eps": conv_mod_norm_eps,
            "momentum": conv_mod_norm_momentum,
        }
        conv_mod_args = (
            output_size,
            cnn_module_kernel,
            activation,
            conv_mod_norm_args,
            dynamic_chunk_training or unified_model_training,
        )
        mult_att_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            simplified_att_score,
        )
        norm_class, norm_args = get_normalization(
            norm_type,
        )
        fn_modules = []
        for _ in range(num_blocks):
            module = lambda: ChunkEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttentionChunk(*mult_att_args),
                PositionwiseFeedForward(*pos_wise_args),
                PositionwiseFeedForward(*pos_wise_args),
                CausalConvolution(*conv_mod_args),
                norm_class=norm_class,
                norm_args=norm_args,
                dropout_rate=dropout_rate,
            )
            fn_modules.append(module)
        self.encoders = MultiBlocks(
            [fn() for fn in fn_modules],
            output_size,
            norm_class=norm_class,
            norm_args=norm_args,
        )
        self.output_size = output_size
        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
        self.short_chunk_size = short_chunk_size
        self.left_chunk_size = left_chunk_size
        self.unified_model_training = unified_model_training
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range
        self.time_reduction_factor = time_reduction_factor
    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
        Where size is the number of features frames after applying subsampling.
        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length
        Returns:
            : Number of raw samples
        """
        return self.embed.get_size_before_subsampling(size) * hop_length
    def get_encoder_input_size(self, size: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
        Where size is the number of features frames after applying subsampling.
        Args:
            size: Number of frames after subsampling.
        Returns:
            : Number of frames before subsampling
        """
        return self.embed.get_size_before_subsampling(size)
    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.
        Args:
            left_context: Number of frames in left context.
            device: Device ID.
        """
        return self.encoders.reset_streaming_cache(left_context, device)
    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Encode input sequences.
        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
        Returns:
           x: Encoder outputs. (B, T_out, D_enc)
           x_len: Encoder outputs lengths. (B,)
           (With unified_model_training, (x_utt, x_chunk, olens) is returned instead.)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )
        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )
        mask = make_source_mask(x_len)
        if self.unified_model_training:
            chunk_size = self.default_chunk_size + torch.randint(
                -self.jitter_range, self.jitter_range + 1, (1,)
            ).item()
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:, ::self.time_reduction_factor, :]
                x_chunk = x_chunk[:, ::self.time_reduction_factor, :]
                olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
            return x_utt, x_chunk, olens
        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            chunk_size = torch.randint(1, max_len, (1,)).item()
            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)
            chunk_mask = None
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:, ::self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
        return x, olens
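Two of the rules in forward can be checked in isolation (hedged sketch with toy numbers): the chunk-size sampling used when dynamic_chunk_training is set, and the output-length rule applied when time_reduction_factor > 1.

import torch

# Dynamic chunk-size sampling: draws above the threshold (roughly a quarter of
# them, with short_chunk_threshold=0.75) fall back to full-utterance training;
# the rest become short chunks in [1, short_chunk_size].
max_len, short_chunk_threshold, short_chunk_size = 120, 0.75, 25
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * short_chunk_threshold):
    chunk_size = max_len
else:
    chunk_size = (chunk_size % short_chunk_size) + 1

# Time reduction: keeping x[:, ::factor, :] leaves floor((olens - 1) / factor) + 1
# valid frames, which is exactly the length rule used above.
factor = 2
olens = torch.tensor([10, 7, 1])
reduced = torch.floor_divide(olens - 1, factor) + 1
print(chunk_size, reduced)  # e.g. 17 tensor([5, 4, 1])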
    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode full sequences with a chunk-wise attention mask (simulated streaming).
        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
            chunk_size: Number of frames in chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.
        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )
        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )
        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:, ::self.time_reduction_factor, :]
        return x
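simu_chunk_forward delegates the mask construction to make_chunk_mask from funasr.modules.nets_utils. As a rough illustration only (a hypothetical reconstruction; the real helper's exact semantics, e.g. how left_chunk_size is interpreted, may differ), a chunk-wise attention mask can be built like this:

import torch

def toy_chunk_mask(size: int, chunk_size: int, left_chunk_size: int = 0) -> torch.Tensor:
    """Hypothetical chunk mask (True = masked): each frame attends to its own
    chunk and, in this toy version, to all earlier chunks when left_chunk_size is 0."""
    chunk = torch.arange(size) // chunk_size  # chunk index of each frame
    q, k = chunk.unsqueeze(1), chunk.unsqueeze(0)
    allowed = k <= q
    if left_chunk_size > 0:
        allowed &= k >= (q - left_chunk_size)
    return ~allowed

print(toy_chunk_mask(6, 2).int())
# tensor([[0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.int32)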
    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode input sequences as chunks.
        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            chunk_size: Number of frames in chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.
        Returns:
           x: Encoder outputs. (B, T_out, D_enc)
        """
        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, None)
        if left_context > 0:
            processed_mask = (
                torch.arange(left_context, device=x.device)
                .view(1, left_context)
                .flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)
        pos_enc = self.pos_enc(x, left_context=left_context)
        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )
        if right_context > 0:
            x = x[:, 0:-right_context, :]
        if self.time_reduction_factor > 1:
            x = x[:, ::self.time_reduction_factor, :]
        return x
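The left-context mask built at the top of chunk_forward can also be checked standalone: cache slots older than the number of frames actually processed so far get masked out (hedged sketch, toy numbers):

import torch

left_context, processed_frames = 8, 5
processed_mask = torch.arange(left_context).view(1, left_context).flip(1)  # [7, 6, ..., 0]
processed_mask = processed_mask >= processed_frames
print(processed_mask.int())  # tensor([[1, 1, 1, 0, 0, 0, 0, 0]], dtype=torch.int32)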
funasr/modules/attention.py
@@ -11,7 +11,7 @@
import numpy
import torch
from torch import nn
from typing import Optional, Tuple
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.
@@ -741,3 +741,221 @@
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """RelPositionMultiHeadedAttention definition.
    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size.
        dropout_rate: Dropout rate.
        simplified_attention_score: Whether to use the simplified attention score
            computation (see compute_simplified_attention_score).
    """
    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        self.d_k = embed_size // num_heads
        self.num_heads = num_heads
        assert self.d_k * num_heads == embed_size, (
            "embed_size (%d) must be divisible by num_heads (%d)"
            % (embed_size, num_heads)
        )
        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)
        self.linear_out = torch.nn.Linear(embed_size, embed_size)
        if simplified_attention_score:
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)
            self.compute_att_score = self.compute_attention_score
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.attn = None
    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.
        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.
        Returns:
            x: Output sequence. (B, H, T_1, T_2)
        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context
        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )
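rel_shift is the usual Transformer-XL shift implemented as a zero-copy as_strided view: on a toy tensor whose last axis enumerates relative positions, output row i reads x[..., i, time1 - 1 - i : time1 - 1 - i + time2], so each query slides its window by one (hedged sketch):

import torch

time1, left_context = 3, 0
time2 = time1 + left_context
n = 2 * time2 - 1  # relative positions from -(time2 - 1) to +(time2 - 1)
x = torch.arange(n).repeat(1, 1, time1, 1)  # (B=1, H=1, time1, n), identical rows

b, h, t1, _ = x.shape
bs, hs, t1s, ns = x.stride()
shifted = x.as_strided(
    (b, h, t1, time2),
    (bs, hs, t1s - ns, ns),
    storage_offset=ns * (time1 - 1),
)
print(shifted[0, 0])
# tensor([[2, 3, 4],
#         [1, 2, 3],
#         [0, 1, 2]])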
    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.
        Reference: https://github.com/k2-fsa/icefall/pull/458
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        pos_enc = self.linear_pos(pos_enc)
        matrix_ac = torch.matmul(query, key.transpose(2, 3))
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation.
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.
        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)
        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        return q, k, v
    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.
        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
        Returns:
           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)
        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )
        return attn_output
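Note the double masking in forward_attention: scores are filled with -inf before the softmax, and the probabilities are re-masked with 0.0 afterwards so that fully masked query rows come out as zeros rather than NaN. A minimal reproduction (stock PyTorch):

import torch

scores = torch.randn(1, 1, 2, 3)  # (B, H, T_1, T_2)
mask = torch.tensor([[[[False, False, True],
                       [True, True, True]]]])  # second query row fully masked

attn = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
attn = attn.masked_fill(mask, 0.0)  # turns the NaN row into zeros
print(attn[0, 0, 1])  # tensor([0., 0., 0.])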
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.
        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
            left_context: Number of frames in left context.
        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
funasr/modules/embedding.py
@@ -423,4 +423,79 @@
        outputs = F.pad(outputs, (pad_left, pad_right))
        outputs = outputs.transpose(1,2)
        return outputs
class StreamingRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding.
    Args:
        size: Module size.
        max_len: Maximum input length.
        dropout_rate: Dropout rate.
    """
    def __init__(
        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
    ) -> None:
        """Construct a RelativePositionalEncoding object."""
        super().__init__()
        self.size = size
        self.pe = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)
    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
        """Reset positional encoding.
        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.
        """
        time1 = x.size(1) + left_context
        if self.pe is not None:
            if self.pe.size(1) >= time1 * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(device=x.device, dtype=x.dtype)
                return
        pe_positive = torch.zeros(time1, self.size)
        pe_negative = torch.zeros(time1, self.size)
        position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.size, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.size)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
        pe_negative = pe_negative[1:].unsqueeze(0)
        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
            dtype=x.dtype, device=x.device
        )
    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute positional encoding.
        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.
        Returns:
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?)
        """
        self.extend_pe(x, left_context=left_context)
        time1 = x.size(1) + left_context
        pos_enc = self.pe[
            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
        ]
        pos_enc = self.dropout(pos_enc)
        return pos_enc
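A small arithmetic check of the slice taken in forward: it always covers time1 + T - 1 relative positions, i.e. 2 * T - 1 when left_context is 0 (plain Python, using the module's default max_len):

T, left_context, max_len = 16, 32, 5000
time1 = T + left_context

pe_len = 2 * max_len - 1  # pe_positive (max_len rows) + pe_negative (max_len - 1 rows)
start = pe_len // 2 - time1 + 1
stop = pe_len // 2 + T
print(stop - start)  # 63 == time1 + T - 1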
funasr/modules/normalization.py
funasr/modules/repeat.py
@@ -6,6 +6,8 @@
"""Repeat the same layer definition."""
+from typing import Dict, List, Optional
import torch
@@ -31,3 +33,93 @@
    """
    return MultiSequential(*[fn(n) for n in range(N)])
class MultiBlocks(torch.nn.Module):
    """MultiBlocks definition.
    Args:
        block_list: Individual blocks of the encoder architecture.
        output_size: Architecture output size.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
    """
    def __init__(
        self,
        block_list: List[torch.nn.Module],
        output_size: int,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Optional[Dict] = None,
    ) -> None:
        """Construct a MultiBlocks object."""
        super().__init__()
        self.blocks = torch.nn.ModuleList(block_list)
        self.norm_blocks = norm_class(output_size, **(norm_args or {}))
        self.num_blocks = len(block_list)
    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.
        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        for idx in range(self.num_blocks):
            self.blocks[idx].reset_streaming_cache(left_context, device)
    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward each block of the encoder architecture.
        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences.
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)
        Returns:
            x: Output sequences. (B, T, D_block_N)
        """
        for block_index, block in enumerate(self.blocks):
            x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
        x = self.norm_blocks(x)
        return x
    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 0,
        left_context: int = 0,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Forward each block of the encoder architecture.
        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
            mask: Source mask. (B, T_2)
            chunk_size: Number of frames in chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.
        Returns:
            x: MultiBlocks output sequences. (B, T, D_block_N)
        """
        for block_idx, block in enumerate(self.blocks):
            x, pos_enc = block.chunk_forward(
                x,
                pos_enc,
                mask,
                chunk_size=chunk_size,
                left_context=left_context,
                right_context=right_context,
            )
        x = self.norm_blocks(x)
        return x
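MultiBlocks.forward is just this threading loop plus one final normalization; a toy stand-in for ChunkEncoderLayer makes the per-block contract explicit (hedged sketch; ToyBlock is hypothetical):

import torch

class ToyBlock(torch.nn.Module):
    def forward(self, x, pos_enc, mask, chunk_mask=None):
        # Same (x, mask, pos_enc) contract as ChunkEncoderLayer.forward.
        return x + 1.0, mask, pos_enc

blocks = torch.nn.ModuleList([ToyBlock() for _ in range(3)])
norm_blocks = torch.nn.LayerNorm(4)

x = torch.zeros(1, 2, 4)
pos_enc = torch.zeros(1, 3, 4)
mask = torch.zeros(1, 2, dtype=torch.bool)
for block in blocks:
    x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=None)
print(norm_blocks(x).shape)  # torch.Size([1, 2, 4])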
funasr/modules/subsampling.py
@@ -11,6 +11,10 @@
from funasr.modules.embedding import PositionalEncoding
import logging
from funasr.modules.streaming_utils.utils import sequence_mask
+from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
+from typing import Optional, Tuple, Union
+import math
class TooShortUttError(Exception):
    """Raised when the utt is too short for subsampling.
@@ -407,3 +411,201 @@
                                                                                  var_dict_tf[name_tf].shape))
        return var_dict_torch_update
class StreamingConvInput(torch.nn.Module):
    """Streaming ConvInput module definition.
    Args:
        input_size: Input size.
        conv_size: Convolution size.
        subsampling_factor: Subsampling factor.
        vgg_like: Whether to use a VGG-like network.
        output_size: Block output dimension.
    """
    def __init__(
        self,
        input_size: int,
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        vgg_like: bool = True,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a ConvInput object."""
        super().__init__()
        if vgg_like:
            if subsampling_factor == 1:
                conv_size1, conv_size2 = conv_size
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                )
                output_proj = conv_size2 * ((input_size // 2) // 2)
                self.subsampling_factor = 1
                self.stride_1 = 1
                self.create_new_mask = self.create_new_vgg_mask
            else:
                conv_size1, conv_size2 = conv_size
                kernel_1 = int(subsampling_factor / 2)
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((kernel_1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((2, 2)),
                )
                output_proj = conv_size2 * ((input_size // 2) // 2)
                self.subsampling_factor = subsampling_factor
                self.create_new_mask = self.create_new_vgg_mask
                self.stride_1 = kernel_1
        else:
            if subsampling_factor == 1:
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, [1, 2], [1, 0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, 3, [1, 2], [1, 0]),
                    torch.nn.ReLU(),
                )
                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
                self.subsampling_factor = subsampling_factor
                self.kernel_2 = 3
                self.stride_2 = 1
                self.create_new_mask = self.create_new_conv2d_mask
            else:
                kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
                    subsampling_factor,
                    input_size,
                )
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, 2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
                    torch.nn.ReLU(),
                )
                output_proj = conv_size * conv_2_output_size
                self.subsampling_factor = subsampling_factor
                self.kernel_2 = kernel_2
                self.stride_2 = stride_2
                self.create_new_mask = self.create_new_conv2d_mask
        self.vgg_like = vgg_like
        self.min_frame_length = 7
        if output_size is not None:
            self.output = torch.nn.Linear(output_proj, output_size)
            self.output_size = output_size
        else:
            self.output = None
            self.output_size = output_proj
    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.
        Args:
            x: ConvInput input sequences. (B, T, D_feats)
            mask: Mask of input sequences. (B, T)
            chunk_size: Number of frames per chunk (None for full-sequence encoding).
        Returns:
            x: ConvInput output sequences. (B, sub(T), D_out)
            mask: Mask of output sequences. (B, sub(T))
        """
        if mask is not None:
            mask = self.create_new_mask(mask)
            olens = max(mask.eq(0).sum(1))
        b, t, f = x.size()
        x = x.unsqueeze(1)  # (B, 1, T, D_feats)
        if chunk_size is not None:
            max_input_length = int(
                chunk_size
                * self.subsampling_factor
                * math.ceil(float(t) / (chunk_size * self.subsampling_factor))
            )
            x = list(map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x))
            x = torch.stack(x, dim=0)
            N_chunks = max_input_length // (chunk_size * self.subsampling_factor)
            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
        x = self.conv(x)
        _, c, _, f = x.size()
        if chunk_size is not None:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:, :olens, :]
        else:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
        if self.output is not None:
            x = self.output(x)
        return x, mask[:, :olens][:, :x.size(1)]
    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create a new mask for VGG output sequences.
        Args:
            mask: Mask of input sequences. (B, T)
        Returns:
            mask: Mask of output sequences. (B, sub(T))
        """
        if self.subsampling_factor > 1:
            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2))
            mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
            mask = mask[:, :vgg2_t_len][:, ::2]
        return mask
    def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create new conformer mask for Conv2d output sequences.
        Args:
            mask: Mask of input sequences. (B, T)
        Returns:
            mask: Mask of output sequences. (B, sub(T))
        """
        if self.subsampling_factor > 1:
            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
        else:
            return mask
    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.
        Args:
            size: Number of frames after subsampling.
        Returns:
            : Number of frames before subsampling.
        """
        return size * self.subsampling_factor
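The VGG mask arithmetic above can be traced on toy numbers: for subsampling_factor=4 the mask is trimmed and decimated by factor // 2 after the first MaxPool2d, then by 2 after the second (hedged sketch mirroring create_new_vgg_mask):

import torch

subsampling_factor = 4
mask = torch.zeros(1, 37, dtype=torch.bool)  # (B, T), False = valid frame

vgg1_t_len = mask.size(1) - mask.size(1) % (subsampling_factor // 2)
mask = mask[:, :vgg1_t_len][:, ::subsampling_factor // 2]  # 37 -> 36 -> 18 frames
vgg2_t_len = mask.size(1) - mask.size(1) % 2
mask = mask[:, :vgg2_t_len][:, ::2]  # 18 -> 9 frames
print(mask.shape)  # torch.Size([1, 9])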
funasr/tasks/asr_transducer.py
@@ -24,7 +24,7 @@
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
-from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder
from funasr.models.e2e_transducer import TransducerModel
from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
from funasr.models.joint_network import JointNetwork
@@ -72,9 +72,9 @@
encoder_choices = ClassChoices(
        "encoder",
        classes=dict(
-                encoder=Encoder,
+                chunk_conformer=ConformerChunkEncoder,
        ),
        default="encoder",
        default="chunk_conformer",
)
decoder_choices = ClassChoices(
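After this change, resolving the task's encoder boils down to a dictionary lookup on the new default key (hedged sketch; the surrounding ClassChoices machinery and the full argument set are elided, and the constructor arguments shown are assumptions):

from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder

# Mirrors the ClassChoices mapping above; "chunk_conformer" is now the default.
encoder_classes = dict(chunk_conformer=ConformerChunkEncoder)
encoder_class = encoder_classes["chunk_conformer"]

encoder = encoder_class(input_size=80, output_size=256)  # 80-dim features assumed
print(type(encoder).__name__)  # ConformerChunkEncoder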