| File was renamed from funasr/models_transducer/encoder/encoder.py |
| | |
| | | """Encoder for Transducer model.""" |
| | | |
| | | from typing import Any, Dict, List, Tuple |
| | | |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models_transducer.encoder.building import ( |
| | | from funasr.models.encoder.chunk_encoder_utils.building import ( |
| | | build_body_blocks, |
| | | build_input_block, |
| | | build_main_parameters, |
| | | build_positional_encoding, |
| | | ) |
| | | from funasr.models_transducer.encoder.validation import validate_architecture |
| | | from funasr.models_transducer.utils import ( |
| | | from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture |
| | | from funasr.modules.nets_utils import ( |
| | | TooShortUttError, |
| | | check_short_utt, |
| | | make_chunk_mask, |
| | | make_source_mask, |
| | | ) |
| | | |
| | | |
| | | class Encoder(torch.nn.Module): |
| | | class ChunkEncoder(torch.nn.Module): |
| | | """Encoder module definition. |
| | | |
| | | Args: |
| | |
| | | |
| | | self.unified_model_training = main_params["unified_model_training"] |
| | | self.default_chunk_size = main_params["default_chunk_size"] |
| | | self.jitter_range = main_params["jitter_range"] |
| | | self.jitter_range = main_params["jitter_range"] |
| | | |
| | | self.time_reduction_factor = main_params["time_reduction_factor"] |
| | | |
| | | self.time_reduction_factor = main_params["time_reduction_factor"] |
| | | def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: |
| | | """Return the corresponding number of sample for a given chunk size, in frames. |
| | | |
| | |
| | | |
| | | """ |
| | | return self.embed.get_size_before_subsampling(size) * hop_length |
| | | |
| | | |
| | | def get_encoder_input_size(self, size: int) -> int: |
| | | """Return the corresponding number of sample for a given chunk size, in frames. |
| | | |
| | |
| | | mask, |
| | | chunk_mask=chunk_mask, |
| | | ) |
| | | |
| | | |
| | | olens = mask.eq(0).sum(1) |
| | | if self.time_reduction_factor > 1: |
| | | x_utt = x_utt[:,::self.time_reduction_factor,:] |
| | |
| | | mask, |
| | | chunk_mask=chunk_mask, |
| | | ) |
| | | |
| | | |
| | | olens = mask.eq(0).sum(1) |
| | | if self.time_reduction_factor > 1: |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 |
| | | |
| | | return x, olens |
| | | |
| | | |
| | | def simu_chunk_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | |
| | | |
| | | if right_context > 0: |
| | | x = x[:, 0:-right_context, :] |
| | | |
| | | |
| | | if self.time_reduction_factor > 1: |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | return x |