| | |
| | | from funasr.modules.subsampling import Conv2dSubsampling8 |
| | | from funasr.modules.subsampling import TooShortUttError |
| | | from funasr.modules.subsampling import check_short_utt |
| | | |
| | | from funasr.modules.subsampling import Conv2dSubsamplingPad |
| | | class ConvolutionModule(nn.Module): |
| | | """ConvolutionModule in Conformer model. |
| | | |
| | |
| | | dropout_rate, |
| | | pos_enc_class(output_size, positional_dropout_rate), |
| | | ) |
| | | elif input_layer == "conv2dpad": |
| | | self.embed = Conv2dSubsamplingPad( |
| | | input_size, |
| | | output_size, |
| | | dropout_rate, |
| | | pos_enc_class(output_size, positional_dropout_rate), |
| | | ) |
| | | elif input_layer == "conv2d2": |
| | | self.embed = Conv2dSubsampling2( |
| | | input_size, |
| | |
| | | or isinstance(self.embed, Conv2dSubsampling2) |
| | | or isinstance(self.embed, Conv2dSubsampling6) |
| | | or isinstance(self.embed, Conv2dSubsampling8) |
| | | or isinstance(self.embed, Conv2dSubsamplingPad) |
| | | ): |
| | | short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) |
| | | if short_status: |