hnluo
2023-09-11 9fcb3cc06b4e324f0913d2f61b89becc2baeef1b
funasr/models/encoder/conformer_encoder.py
@@ -12,16 +12,15 @@
import torch
from torch import nn
from typeguard import check_argument_types
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttentionChunk,
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
)
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.embedding import (
    PositionalEncoding,  # noqa: H301
    ScaledPositionalEncoding,  # noqa: H301
@@ -30,7 +29,6 @@
    StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.normalization import get_normalization
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
@@ -308,7 +306,7 @@
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_class: torch.nn.Module = LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
    ) -> None:
@@ -534,7 +532,6 @@
            interctc_use_conditioning: bool = False,
            stochastic_depth_rate: Union[float, List[float]] = 0.0,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
@@ -895,7 +892,7 @@
        return x, cache
class ConformerChunkEncoder(torch.nn.Module):
class ConformerChunkEncoder(AbsEncoder):
    """Encoder module definition.
    Args:
        input_size: Input size.
@@ -940,12 +937,10 @@
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
        **activation_parameters,
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()
        assert check_argument_types()
        self.embed = StreamingConvInput(
            input_size,
@@ -961,7 +956,7 @@
        )
        activation = get_activation(
            activation_type, **activation_parameters
            activation_type
       )        
        pos_wise_args = (
@@ -991,9 +986,6 @@
            simplified_att_score,
        )
        norm_class, norm_args = get_normalization(
            norm_type,
        )
        fn_modules = []
        for _ in range(num_blocks):
@@ -1003,8 +995,6 @@
                PositionwiseFeedForward(*pos_wise_args),
                PositionwiseFeedForward(*pos_wise_args),
                CausalConvolution(*conv_mod_args),
                norm_class=norm_class,
                norm_args=norm_args,
                dropout_rate=dropout_rate,
            )
            fn_modules.append(module)        
@@ -1012,11 +1002,9 @@
        self.encoders = MultiBlocks(
            [fn() for fn in fn_modules],
            output_size,
            norm_class=norm_class,
            norm_args=norm_args,
        )
        self.output_size = output_size
        self._output_size = output_size
        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
@@ -1028,6 +1016,9 @@
        self.jitter_range = jitter_range
        self.time_reduction_factor = time_reduction_factor
    def output_size(self) -> int:
        return self._output_size
    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
@@ -1084,10 +1075,13 @@
                limit_size,
            )
        mask = make_source_mask(x_len)
        mask = make_source_mask(x_len).to(x.device)
        if self.unified_model_training:
            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            if self.training:
                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            else:
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
@@ -1119,12 +1113,15 @@
        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            chunk_size = torch.randint(1, max_len, (1,)).item()
            if self.training:
                chunk_size = torch.randint(1, max_len, (1,)).item()
            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
                if chunk_size > (max_len * self.short_chunk_threshold):
                    chunk_size = max_len
                else:
                    chunk_size = (chunk_size % self.short_chunk_size) + 1
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
@@ -1151,7 +1148,46 @@
            x = x[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
        return x, olens
        return x, olens, None
    def full_utt_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.
        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
        Returns:
           x: Encoder outputs. (B, T_out, D_enc)
           x_len: Encoder outputs lenghts. (B,)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )
        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        pos_enc = self.pos_enc(x)
        x_utt = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=None,
        )
        if self.time_reduction_factor > 1:
            x_utt = x_utt[:,::self.time_reduction_factor,:]
        return x_utt
    def simu_chunk_forward(
        self,