lyblsgo
2023-04-21 7bfc4a84fc2d882f34928a033a6d5b60ff72fe19
funasr/models/encoder/conformer_encoder.py
@@ -30,7 +30,6 @@
    StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.normalization import get_normalization
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
@@ -895,7 +894,7 @@
        return x, cache
class ConformerChunkEncoder(torch.nn.Module):
class ConformerChunkEncoder(AbsEncoder):
    """Encoder module definition.
    Args:
        input_size: Input size.
@@ -940,7 +939,6 @@
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
        **activation_parameters,
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()
@@ -961,7 +959,7 @@
        )
        activation = get_activation(
            activation_type, **activation_parameters
            activation_type
       )        
        pos_wise_args = (
@@ -991,9 +989,6 @@
            simplified_att_score,
        )
        norm_class, norm_args = get_normalization(
            norm_type,
        )
        fn_modules = []
        for _ in range(num_blocks):
@@ -1003,8 +998,6 @@
                PositionwiseFeedForward(*pos_wise_args),
                PositionwiseFeedForward(*pos_wise_args),
                CausalConvolution(*conv_mod_args),
                norm_class=norm_class,
                norm_args=norm_args,
                dropout_rate=dropout_rate,
            )
            fn_modules.append(module)        
@@ -1012,11 +1005,9 @@
        self.encoders = MultiBlocks(
            [fn() for fn in fn_modules],
            output_size,
            norm_class=norm_class,
            norm_args=norm_args,
        )
        self.output_size = output_size
        self._output_size = output_size
        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
@@ -1029,6 +1020,9 @@
        self.time_reduction_factor = time_reduction_factor
    def output_size(self) -> int:
        return self._output_size
    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
        Where size is the number of features frames after applying subsampling.