| | |
| | | """Repeat the same layer definition.""" |
| | | |
| | | from typing import Dict, List, Optional |
| | | |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | import torch |
| | | |
| | | |
| | |
| | | self, |
| | | block_list: List[torch.nn.Module], |
| | | output_size: int, |
| | | norm_class: torch.nn.Module = torch.nn.LayerNorm, |
| | | norm_args: Optional[Dict] = None, |
| | | norm_class: torch.nn.Module = LayerNorm, |
| | | ) -> None: |
| | | """Construct a MultiBlocks object.""" |
| | | super().__init__() |
| | | |
| | | self.blocks = torch.nn.ModuleList(block_list) |
| | | self.norm_blocks = norm_class(output_size, **norm_args) |
| | | self.norm_blocks = norm_class(output_size) |
| | | |
| | | self.num_blocks = len(block_list) |
| | | |